Date: Fri, 15 Sep 2023 14:18:22 +0800
Subject: [PATCH 08/58] Fixed some grammar and typo errors in the documentation and
 code under applications/ (#4127)
Co-authored-by: flybird11111 <1829166702@qq.com>
---
applications/Chat/README.md | 6 ++----
applications/Chat/coati/experience_maker/base.py | 2 +-
applications/Chat/coati/models/lora.py | 2 +-
applications/Chat/coati/ray/detached_replay_buffer.py | 2 +-
applications/Chat/coati/ray/utils.py | 2 +-
applications/Chat/evaluate/README.md | 2 +-
applications/Chat/evaluate/gpt_evaluate.py | 8 ++++----
applications/Chat/examples/community/peft/README.md | 2 +-
8 files changed, 12 insertions(+), 14 deletions(-)
diff --git a/applications/Chat/README.md b/applications/Chat/README.md
index 5a1187ab503d..59e2c4548365 100644
--- a/applications/Chat/README.md
+++ b/applications/Chat/README.md
@@ -200,7 +200,6 @@ We provide an online inference server and a benchmark. We aim to run inference o
We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference.
Online inference server scripts can help you deploy your own services.
-
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
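+
+As a rough illustration of what 8-bit RTN amounts to (a generic round-to-nearest sketch, not this repository's implementation):
+
+```python
+import torch
+
+def rtn_quantize_int8(w: torch.Tensor):
+    # Round-to-nearest: map values onto an int8 grid with one per-tensor scale.
+    scale = w.abs().max() / 127
+    q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
+    return q, scale
+
+q, scale = rtn_quantize_int8(torch.randn(4, 4))
+w_hat = q.float() * scale  # dequantized approximation of the original weights
+```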
## Coati7B examples
@@ -428,7 +427,7 @@ Thanks so much to all of our amazing contributors!
-- An open-source low cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
+- An open-source low-cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
@@ -469,8 +468,7 @@ Coati is developed by ColossalAI Team:
- [ofey404](https://github.com/ofey404)
- [Wenhao Chen](https://github.com/CWHer)
-The Phd student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
-
+The PhD students from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
- [Xue Fuzhao](https://github.com/XueFuzhao)
diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py
index ff75852576c8..b4646f282f0c 100644
--- a/applications/Chat/coati/experience_maker/base.py
+++ b/applications/Chat/coati/experience_maker/base.py
@@ -10,7 +10,7 @@
@dataclass
class Experience:
"""Experience is a batch of data.
- These data should have the the sequence length and number of actions.
+ These data should have the sequence length and number of actions.
Left padding for sequences is applied.
Shapes of each tensor:
diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 546f675d7d37..f1597da540a7 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -48,7 +48,7 @@ def __init__(
def reset_parameters(self):
if hasattr(self, 'lora_A'):
- # initialize A the same way as the default for nn.Linear and B to zero
+ # Initialize A with the default values for nn.Linear and set B to zero.
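+ # With B initialized to zero, the initial LoRA update (lora_B @ lora_A) is zero,
+ # so fine-tuning starts exactly from the pretrained weights.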
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py
index 7b9df2ee139b..e04bf5ccb881 100644
--- a/applications/Chat/coati/ray/detached_replay_buffer.py
+++ b/applications/Chat/coati/ray/detached_replay_buffer.py
@@ -16,7 +16,7 @@
class DetachedReplayBuffer:
'''
Detached replay buffer. Share Experience across workers on the same node.
- Therefore a trainer node is expected to have only one instance.
+ Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
Args:
diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py
index 761186b95ee5..391ffe7a91a9 100644
--- a/applications/Chat/coati/ray/utils.py
+++ b/applications/Chat/coati/ray/utils.py
@@ -116,7 +116,7 @@ def get_model_numel(model: nn.Module) -> int:
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
target_receivers = []
if num_senders <= num_receivers or allow_idle_sender:
- # a sender will send data to one or more than one receivers
+ # a sender will send data to one or more receivers
# a receiver only has one sender
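+ # e.g. with num_senders=2 and num_receivers=4: sender 0 -> receivers [0, 2], sender 1 -> receivers [1, 3]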
for i in range(num_receivers):
if i % num_senders == sender_idx:
diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md
index 68b03be16a30..0a97ae72f9d0 100644
--- a/applications/Chat/evaluate/README.md
+++ b/applications/Chat/evaluate/README.md
@@ -348,7 +348,7 @@ For example, if you want to add a new metric `persuasiveness` into category `bra
How can I add a new UniEval evaluation metric?
-For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in function `add_question` in `unieval/utils.py`. Please do note that how effectively the model would evaluate this metric is unknown and you may need some experiments to test whether the model is capable of evaluating this metric.
+For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in the function `add_question` in `unieval/utils.py`. Please note that how effectively the model can evaluate this metric is unknown, and you may need some experiments to test whether the model is capable of evaluating it.
```python
if task == 'data2text':
diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/Chat/evaluate/gpt_evaluate.py
index f8cfb8d0f7e5..6fcbe63d0253 100644
--- a/applications/Chat/evaluate/gpt_evaluate.py
+++ b/applications/Chat/evaluate/gpt_evaluate.py
@@ -576,7 +576,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float:
for key, value in logprobs.items():
# Sometimes the key will be one byte of a unicode character which takes the form of "bytes:\\xe7".
- # It is meaningless and thus we don't calculate probability.
+ # It is meaningless, and thus we don't calculate probability.
if "bytes" in key:
continue
# results[0] is the score which corresponds to the key(predicted token).
@@ -621,7 +621,7 @@ def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[st
Args:
model_name: name of the model for saving evaluation results.
- gpt_evaluation_results: evaluations results for all of the model answers.
+ gpt_evaluation_results: evaluation results for all the model answers.
save_path: path to save GPT evaluation statistics.
"""
@@ -641,7 +641,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
Args:
model_name: name of the model for saving statistics.
- evaluations: evaluations for all of the model answers.
+ evaluations: evaluations for all the model answers.
save_path: path to save GPT evaluation statistics.
"""
@@ -663,7 +663,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
for evaluation in data:
for metric in metrics:
if evaluation["evaluation"][metric] == {}:
- # This means after 3 retries, the server still returns an error and we set the score to 0.
+ # This means after 3 retries, the server still returns an error, and we set the score to 0.
scores[metric].append(0)
elif evaluation["evaluation"][metric]["logprobs"] is not None:
scores[metric].append(
diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md
index 8b2edc48cd99..ada3a16296af 100644
--- a/applications/Chat/examples/community/peft/README.md
+++ b/applications/Chat/examples/community/peft/README.md
@@ -20,7 +20,7 @@ pip install .
For SFT training, just call train_peft_sft.py
-Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py.
+Its arguments are almost identical to those of train_sft.py, except that it adds a new eval_dataset argument if you have an eval_dataset file. The data file is just a plain data file; please check the format in easy_dataset.py.
For stage-3 rlhf training, call train_peft_prompts.py.
Its arguments are almost identical to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported.
From 46162632e5dc8c0d7f6928b85d55b4d557615a8e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 15 Sep 2023 14:32:04 +0800
Subject: [PATCH 09/58] [shardformer] update pipeline parallel document (#4725)
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
* [shardformer] update pipeline parallel document
---
docs/source/en/features/pipeline_parallel.md | 222 +++++++++++-------
.../zh-Hans/features/pipeline_parallel.md | 218 ++++++++++-------
2 files changed, 276 insertions(+), 164 deletions(-)
diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md
index 8b5f228a9e5e..cb19f9815bf2 100644
--- a/docs/source/en/features/pipeline_parallel.md
+++ b/docs/source/en/features/pipeline_parallel.md
@@ -1,14 +1,15 @@
# Pipeline Parallel
-Author: Guangyang Lu, Hongxin Liu, Yongbin Li
+Author: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang
**Prerequisite**
-- [Define Your Configuration](../basics/define_your_config.md)
-- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
-- [Configure Parallelization](../basics/configure_parallelization.md)
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Booster API](../basics/booster_api.md)
+- [Shardformer](../features/shardformer.md)
+- [Booster Plugins](../basics/booster_plugins.md)
**Example Code**
-- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel)
+- [Fine-tune Bert with pipeline](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py)
**Related Paper**
- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
@@ -17,7 +18,7 @@ Author: Guangyang Lu, Hongxin Liu, Yongbin Li
## Quick introduction
-In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use ResNet and Cifar as example.
+In this tutorial, you will learn how to use pipeline parallelism. In Colossal-AI, we use the 1F1B pipeline introduced by NVIDIA. Since ViT and ImageNet are too large to use in this case, we use the Bert model and the GLUE dataset as an example.
## Table Of Content
@@ -25,7 +26,7 @@ In this tutorial we will cover:
1. Introduction of 1F1B pipeline.
2. Usage of non-interleaved and interleaved schedules.
-3. Training ResNet with pipeline.
+3. Fine-tune Bert with pipeline.
## Introduction of 1F1B pipeline
@@ -60,101 +61,158 @@ In this schedule, each device can perform computation for multiple subsets of la
This mode is both memory-efficient and time-efficient.
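+
+As a rough sketch of why the interleaved schedule helps (using the 1F1B bubble-fraction formulas from the Megatron-LM analysis — an assumption for illustration, not code from this repository):
+
+```python
+def bubble_fraction(p: int, m: int, v: int = 1) -> float:
+    # p: pipeline stages, m: micro-batches, v: model chunks per device (v=1 is plain 1F1B)
+    bubble = (p - 1) / v          # idle time, in units of one micro-batch of work
+    return bubble / (m + bubble)  # fraction of total time spent idle
+
+print(bubble_fraction(p=4, m=8))       # non-interleaved 1F1B: ~0.27
+print(bubble_fraction(p=4, m=8, v=2))  # interleaved, 2 chunks per device: ~0.16
+```
+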
-## Usage of non-interleaved and interleaved schedule
+## Colossal-AI's Implementation
-In Colossal-AI, we provided both non-interleaved(as `PipelineSchedule`) and interleaved schedule(as `InterleavedPipelineSchedule`).
+In Colossal-AI, pipeline parallelism relies on the `scheduler` and [`Shardformer`](../features/shardformer.md). We provide both a non-interleaved (`OneForwardOneBackwardSchedule`) and an interleaved (`InterleavedSchedule`) schedule. `Shardformer` splits the model's layers and replaces the model's `forward` function to make the model compatible with the scheduler.
-You just need to set `NUM_MICRO_BATCHES` in config file and set `NUM_CHUNKS` in config file if you want to use Interleaved Pipeline Schedule. If you certainly know the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `TENSOR_SHAPE` in config file to further reduce communication. Otherwise, you can just ignore `tensor_shape`, and the shape will be exchanged over pipeline stages automatically. Then we will generate an appropriate schedule for you.
+In Colossal-AI, the `HybridParallelPlugin` encapsulates the pipeline execution strategy. It manages the pipeline-parallel communication groups and a scheduler. When boosting the model with this plugin, the model's layers are split by calling the `shardformer.optimize` function, and then `execute_pipeline` is called to execute the model in segments using `OneForwardOneBackwardSchedule`, the default scheduler used in `HybridParallelPlugin`; `InterleavedSchedule` will be integrated later.
-## Training ResNet with pipeline
+You can customize your parallel strategy by setting parameters for the `HybridParallelPlugin`.
-Let's build the `ResNet` model first with Colossal PipelinableContext:
+For more usage details, please refer to the [documentation](../basics/booster_plugins.md) for `HybridParallelPlugin`.
+
+## Fine-tune Bert with pipeline
+
+First, we define the necessary training components, including the model, dataloader, optimizer, lr_scheduler, and criterion (the snippets below are grouped by component rather than by execution order; see the [complete example script](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py) for a runnable ordering):
```python
-import os
-from typing import Callable, List, Optional, Type, Union
+import argparse
+from typing import Callable, List, Union
+
import torch
import torch.nn as nn
+from data import GLUEDataBuilder
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import (
+ AlbertForSequenceClassification,
+ AutoConfig,
+ BertForSequenceClassification,
+ get_linear_schedule_with_warmup,
+)
+
import colossalai
-import colossalai.nn as col_nn
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.legacy.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, get_dataloader
-from colossalai.context import ParallelMode
-from colossalai.pipeline.pipelinable import PipelinableContext
+# Define some config
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+
+coordinator = DistCoordinator()
+
+def move_to_cuda(batch):
+ return {k: v.cuda() for k, v in batch.items()}
+
+
+# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'.
+def _criterion(outputs, inputs):
+ return outputs.loss
+
+# Define optimizer
+lr = LEARNING_RATE
+no_decay = ["bias", "LayerNorm.weight"]
+optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": WEIGHT_DECAY,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+]
-from titans.dataloader.cifar10 import build_cifar
-from torchvision.models import resnet50
-from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
+optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
-# Define some config
-BATCH_SIZE = 64
-NUM_EPOCHS = 2
-NUM_CHUNKS = 1
-CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
-
-# Train
-disable_existing_loggers()
-parser = colossalai.get_default_parser()
-args = parser.parse_args()
-colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
-logger = get_dist_logger()
-pipelinable = PipelinableContext()
-
-# build model
-with pipelinable:
- model = resnet50()
-```
-Define an execution sequence.
-```python
-exec_seq = [
- 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool',
- (lambda x: torch.flatten(x, 1), "behind"), 'fc'
-]
-pipelinable.to_layer_list(exec_seq)
+# Define lr_scheduler
+total_steps = len(train_dataloader) * NUM_EPOCHS
+num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_steps,
+)
+
+
+# Define Bert model
+model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=cfg).cuda()
+
+# Define a dataloader
+data_builder = GLUEDataBuilder(model_name,
+ plugin,
+ args.task,
+ train_batch_size=BATCH_SIZE,
+ eval_batch_size=BATCH_SIZE)
+train_dataloader = data_builder.train_dataloader()
```
-Partition the model into pipeline.
+Define a booster with the `HybridParallelPlugin`.
```python
-model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
+plugin = HybridParallelPlugin(tp_size=1,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ zero_stage=1,
+ precision='fp16',
+ initial_scale=1)
+booster = Booster(plugin=plugin)
```
-In this tutorial, we use `Trainer` to train `ResNet`:
+Boost these training components with the booster we created.
```python
-# build criterion
-criterion = nn.CrossEntropyLoss()
-
-# optimizer
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
-# build dataloader
-root = os.environ.get('DATA', './data')
-train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32)
-
-lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1)
-engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
- train_dataloader, test_dataloader,
- lr_scheduler)
-timer = MultiTimer()
+model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+ optimizer,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
+```
-trainer = Trainer(engine=engine, timer=timer, logger=logger)
+Finally, train the model.
-hook_list = [
- hooks.LossHook(),
- hooks.AccuracyHook(col_nn.metric.Accuracy()),
- hooks.LogMetricByEpochHook(logger),
- hooks.LRSchedulerHook(lr_scheduler, by_epoch=True)
-]
-
-trainer.fit(train_dataloader=train_dataloader,
- epochs=NUM_EPOCHS,
- test_dataloader=test_dataloader,
- test_interval=1,
- hooks=hook_list,
- display_progress=True)
+```python
+# Define a train function
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+ train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
+ is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
+ total_step = len(train_dataloader)
+
+ model.train()
+ optimizer.zero_grad()
+ # convert train_dataloader to an iterator
+ train_dataloader_iter = iter(train_dataloader)
+ with tqdm(range(total_step),
+ desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+ disable=not (is_pp_last_stage)) as pbar:
+ # Forward pass
+ for _ in pbar:
+ outputs = booster.execute_pipeline(train_dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+# Train model
+for epoch in range(NUM_EPOCHS):
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
-We use `2` pipeline stages and the batch will be split into `4` micro batches.
+We use `2` pipeline stages and a micro batch size of `1`. (These parameters can be configured to appropriate values.)
diff --git a/docs/source/zh-Hans/features/pipeline_parallel.md b/docs/source/zh-Hans/features/pipeline_parallel.md
index 1497dc399f6c..e688020556d8 100644
--- a/docs/source/zh-Hans/features/pipeline_parallel.md
+++ b/docs/source/zh-Hans/features/pipeline_parallel.md
@@ -1,14 +1,15 @@
# Pipeline Parallel
-Author: Guangyang Lu, Hongxin Liu, Yongbin Li
+Author: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang
**Prerequisite**
-- [Define Your Configuration](../basics/define_your_config.md)
-- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
-- [Configure Parallelization](../basics/configure_parallelization.md)
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Booster API](../basics/booster_api.md)
+- [Shardformer](../features/shardformer.md)
+- [Booster Plugins](../basics/booster_plugins.md)
**Example Code**
-- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel)
+- [Fine-tune Bert with pipeline](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py)
**Related Paper**
- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
@@ -17,7 +18,7 @@
## Quick introduction
-In this tutorial, you will learn how to use pipeline parallelism. In Colossal-AI, we use the 1F1B pipeline introduced by NVIDIA. Since ViT and ImageNet are too large to use in this case, we use ResNet and CIFAR as an example.
+In this tutorial, you will learn how to use pipeline parallelism. In Colossal-AI, we use the 1F1B pipeline introduced by NVIDIA. Since ViT and ImageNet are too large to use in this case, we use the Bert model and the GLUE dataset as an example.
## Table Of Content
@@ -25,7 +26,7 @@
1. Introduction of the 1F1B pipeline.
2. Usage of non-interleaved and interleaved schedules.
-3. Training ResNet with pipeline.
+3. Fine-tune Bert with pipeline.
## Introduction of 1F1B pipeline
@@ -59,101 +60,154 @@
This mode is both memory-efficient and time-efficient.
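+
+As a rough sketch of why the interleaved schedule helps (using the 1F1B bubble-fraction formulas from the Megatron-LM analysis — an assumption for illustration, not code from this repository):
+
+```python
+def bubble_fraction(p: int, m: int, v: int = 1) -> float:
+    # p: pipeline stages, m: micro-batches, v: model chunks per device (v=1 is plain 1F1B)
+    bubble = (p - 1) / v          # idle time, in units of one micro-batch of work
+    return bubble / (m + bubble)  # fraction of total time spent idle
+
+print(bubble_fraction(p=4, m=8))       # non-interleaved 1F1B: ~0.27
+print(bubble_fraction(p=4, m=8, v=2))  # interleaved, 2 chunks per device: ~0.16
+```
+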
-## Usage of schedules
+## Colossal-AI's Implementation
-In Colossal-AI, we provide both non-interleaved (`PipelineSchedule`) and interleaved (`InterleavedPipelineSchedule`) schedules.
+In Colossal-AI, pipeline parallelism relies on the `scheduler` and `Shardformer`. We provide both a non-interleaved (`OneForwardOneBackwardSchedule`) and an interleaved (`InterleavedSchedule`) schedule. `Shardformer` splits the model's layers and replaces the model's `forward` function to make the model compatible with the scheduler.
-You just need to set `NUM_MICRO_BATCHES` in the config file, and also set `NUM_CHUNKS` when you want to use the interleaved schedule. If you know for certain the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `tensor_shape` to further reduce communication; otherwise you can ignore `tensor_shape`, and the shapes will be exchanged between pipeline stages automatically. We will generate an appropriate schedule for your pipeline-parallel training based on the config file you provide.
+In Colossal-AI, the `HybridParallelPlugin` encapsulates the pipeline execution strategy. It manages the pipeline-parallel communication groups and a `scheduler`. When boosting the model with this plugin, the model's layers are split by calling the `shardformer.optimize` function, and then `execute_pipeline` is called to execute parts of the model using the `scheduler`. For now, `HybridParallelPlugin` only supports `OneForwardOneBackwardSchedule`; `InterleavedSchedule` will be supported soon.
-## Training ResNet with pipeline
+You can customize your parallel strategy by setting the parameters of `HybridParallelPlugin`. For more usage details, please refer to the [documentation](../basics/booster_plugins.md) of `HybridParallelPlugin`.
-Let's first build the `ResNet` model with Colossal's `PipelinableContext`:
+## Fine-tune the Bert model with pipeline
+
+First, we define the necessary training components, including the `model`, `dataloader`, `optimizer`, `lr_scheduler`, and `criterion` (the snippets below are grouped by component rather than by execution order; see the [complete example script](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py) for a runnable ordering):
```python
-import os
-from typing import Callable, List, Optional, Type, Union
+import argparse
+from typing import Callable, List, Union
+
import torch
import torch.nn as nn
+from data import GLUEDataBuilder
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import (
+ AlbertForSequenceClassification,
+ AutoConfig,
+ BertForSequenceClassification,
+ get_linear_schedule_with_warmup,
+)
+
import colossalai
-import colossalai.nn as col_nn
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.legacy.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, get_dataloader
-from colossalai.context import ParallelMode
-from colossalai.pipeline.pipelinable import PipelinableContext
+# Define some config
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+
+coordinator = DistCoordinator()
+
+def move_to_cuda(batch):
+ return {k: v.cuda() for k, v in batch.items()}
+
+# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'.
+def _criterion(outputs, inputs):
+ return outputs.loss
+
+# Define optimizer
+lr = LEARNING_RATE
+no_decay = ["bias", "LayerNorm.weight"]
+optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": WEIGHT_DECAY,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+]
-from titans.dataloader.cifar10 import build_cifar
-from torchvision.models import resnet50
-from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
+optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
-# Define some config
-BATCH_SIZE = 64
-NUM_EPOCHS = 2
-NUM_CHUNKS = 1
-CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
-
-# Train
-disable_existing_loggers()
-parser = colossalai.get_default_parser()
-args = parser.parse_args()
-colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
-logger = get_dist_logger()
-pipelinable = PipelinableContext()
-
-# build model
-with pipelinable:
- model = resnet50()
+
+# Define lr_scheduler
+total_steps = len(train_dataloader) * NUM_EPOCHS
+num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_steps,
+)
+
+
+# Define Bert model
+model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=cfg).cuda()
+
+# Define a dataloader
+data_builder = GLUEDataBuilder(model_name,
+ plugin,
+ args.task,
+ train_batch_size=BATCH_SIZE,
+ eval_batch_size=BATCH_SIZE)
+train_dataloader = data_builder.train_dataloader()
```
-Given the partition order, modules are specified directly by name, while some functions need to be added manually.
+Initialize a booster with the `HybridParallelPlugin`.
```python
-exec_seq = [
- 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool',
- (lambda x: torch.flatten(x, 1), "behind"), 'fc'
-]
-pipelinable.to_layer_list(exec_seq)
+plugin = HybridParallelPlugin(tp_size=1,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ zero_stage=1,
+ precision='fp16',
+ initial_scale=1)
+booster = Booster(plugin=plugin)
```
-Partition the model into pipeline stages.
+Use the `booster` to inject the optimization features into the training components.
```python
-model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
+model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+ optimizer,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
```
-We use `Trainer` to train `ResNet`:
+Finally, train the model.
```python
-# build criterion
-criterion = nn.CrossEntropyLoss()
-
-# optimizer
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
-# build dataloader
-root = os.environ.get('DATA', './data')
-train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32)
-
-lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1)
-engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
- train_dataloader, test_dataloader,
- lr_scheduler)
-timer = MultiTimer()
-
-trainer = Trainer(engine=engine, timer=timer, logger=logger)
-
-hook_list = [
- hooks.LossHook(),
- hooks.AccuracyHook(col_nn.metric.Accuracy()),
- hooks.LogMetricByEpochHook(logger),
- hooks.LRSchedulerHook(lr_scheduler, by_epoch=True)
-]
-
-trainer.fit(train_dataloader=train_dataloader,
- epochs=NUM_EPOCHS,
- test_dataloader=test_dataloader,
- test_interval=1,
- hooks=hook_list,
- display_progress=True)
+# Define a train function
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+ train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
+ is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
+ total_step = len(train_dataloader)
+
+ model.train()
+ optimizer.zero_grad()
+ # convert train_dataloader to an iterator
+ train_dataloader_iter = iter(train_dataloader)
+ with tqdm(range(total_step),
+ desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+ disable=not (is_pp_last_stage)) as pbar:
+ # Forward pass
+ for _ in pbar:
+ outputs = booster.execute_pipeline(train_dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+# Train model
+for epoch in range(NUM_EPOCHS):
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
-We use `2` pipeline stages, and the batch will be split into `4` micro batches.
+We use `2` pipeline stages and a micro batch size of `1`. (These parameters can be configured to appropriate values.)
From cd4e61d149db3b98435cf6c90a389d8d1dff21e6 Mon Sep 17 00:00:00 2001
From: Pengtai Xu
Date: Fri, 15 Sep 2023 15:52:18 +0800
Subject: [PATCH 10/58] [legacy] remove deterministic data loader test
---
.../test_deterministic_dataloader.py | 73 -------------------
1 file changed, 73 deletions(-)
delete mode 100644 tests/test_data/test_deterministic_dataloader.py
diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py
deleted file mode 100644
index 283b5cc35279..000000000000
--- a/tests/test_data/test_deterministic_dataloader.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os
-from pathlib import Path
-
-import pytest
-import torch
-import torch.distributed as dist
-from torchvision import datasets, transforms
-
-import colossalai
-from colossalai.context import Config, ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_dataloader
-
-CONFIG = Config(
- dict(
- train_data=dict(
- dataset=dict(
- type='CIFAR10',
- root=Path(os.environ['DATA']),
- train=True,
- download=True,
- ),
- dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
- ),
- parallel=dict(
- pipeline=dict(size=1),
- tensor=dict(size=1, mode=None),
- ),
- seed=1024,
- ))
-
-
-def run_data_sampler(rank, world_size, port):
- dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
- colossalai.launch(**dist_args)
-
- # build dataset
- transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
- transform_pipeline = transforms.Compose(transform_pipeline)
- dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
-
- # build dataloader
- dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)
-
- data_iter = iter(dataloader)
- img, label = data_iter.next()
- img = img[0]
-
- if gpc.get_local_rank(ParallelMode.DATA) != 0:
- img_to_compare = img.clone()
- else:
- img_to_compare = img
- dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))
-
- if gpc.get_local_rank(ParallelMode.DATA) != 0:
- # this is without sampler
- # this should be false if data parallel sampler to given to the dataloader
- assert torch.equal(img,
- img_to_compare), 'Same image was distributed across ranks and expected it to be the same'
- torch.cuda.empty_cache()
-
-
-@rerun_if_address_is_in_use()
-def test_data_sampler():
- spawn(run_data_sampler, 4)
-
-
-if __name__ == '__main__':
- test_data_sampler()
From 6a03c933a0ef43090c8add9303898287b1482d74 Mon Sep 17 00:00:00 2001
From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Date: Fri, 15 Sep 2023 16:09:32 +0800
Subject: [PATCH 11/58] [shardformer] update seq parallel document (#4730)
* update doc of seq parallel
* fix typo
---
docs/source/en/features/shardformer.md | 17 +++++++++++++++--
docs/source/zh-Hans/features/shardformer.md | 16 +++++++++++++++-
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
index 10e03e963a95..ca23f07421d1 100644
--- a/docs/source/en/features/shardformer.md
+++ b/docs/source/en/features/shardformer.md
@@ -1,6 +1,6 @@
# Shardformer
-Author: [Baizhou Zhang](https://github.com/Fridge003)
+Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer)
**Prerequisite**
- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
@@ -16,7 +16,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003)
- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)
- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
-
+- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198)
## Introduction
@@ -74,6 +74,18 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs.
```
when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
+### Sequence Parallelism
+
+Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used together with 1D tensor parallelism to further reduce the memory occupation of activations during computation.
+
+1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, and $\vec{g}$ performs one All-Reduce in the forward pass to gather the whole outputs from all devices.
+
+2. When using sequence parallelism, $g$ needs to perform an All-Gather during the forward pass to gather the inputs along the sequence dimension, and a Reduce-Scatter during the backward pass to split the gradient. $\vec{g}$ needs to perform a Reduce-Scatter to split the output of the row linear layer of tensor parallelism across devices along the sequence dimension, and an All-Gather to gather the whole gradient during the backward pass.
+
+3. The NCCL implementation of All-Reduce adopts the `Ring All-Reduce` approach, which consists of a Reduce-Scatter operation and an All-Gather operation of equal cost. Therefore, combining sequence parallelism with tensor parallelism does not introduce additional communication overhead compared to tensor parallelism alone (see the sketch after this list).
+
+4. One important thing to note is that when using sequence parallelism with the `Column Linear` module of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. During the forward pass, only the portion of the input split along the sequence dimension is retained, with shape $(batch, sequence\_len/k, hidden\_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. In the implementation, however, this All-Gather can be overlapped with the gradient computation, in which case no additional communication overhead is introduced (this corresponds to the `enable_sequence_overlap` parameter in `Shardformer`).
+
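+As a minimal sketch of point 3 (assuming the standard ring-collective cost model; illustrative, not code from this repository): on $k$ devices, a ring All-Reduce of $n$ bytes is exactly one Reduce-Scatter plus one All-Gather, so splitting the two halves between the forward and backward passes is cost-neutral:
+
+```python
+def reduce_scatter_bytes(n: int, k: int) -> float:
+    # Each device sends (k-1)/k of the buffer in a ring Reduce-Scatter.
+    return (k - 1) / k * n
+
+def all_gather_bytes(n: int, k: int) -> float:
+    # A ring All-Gather moves the same volume per device.
+    return (k - 1) / k * n
+
+def ring_all_reduce_bytes(n: int, k: int) -> float:
+    # Ring All-Reduce = one Reduce-Scatter followed by one All-Gather.
+    return reduce_scatter_bytes(n, k) + all_gather_bytes(n, k)
+
+n, k = 4 * 1024**2, 4  # e.g. a 4 MiB activation on a 4-way tensor-parallel group
+assert ring_all_reduce_bytes(n, k) == reduce_scatter_bytes(n, k) + all_gather_bytes(n, k)
+```
+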
## How Shardformer Works
Generally, Shardformer works through the following four kinds of *replacements*:
@@ -100,6 +112,7 @@ As a result, the optimizer will only compute the states corresponding to these p
All of these replacements are implemented with manually written policies and forward functions.
If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
+
## Supporting Information
Model/Feature Compatibility Matrix:
diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md
index e0d8df2c90c8..7de0c41c10d7 100644
--- a/docs/source/zh-Hans/features/shardformer.md
+++ b/docs/source/zh-Hans/features/shardformer.md
@@ -1,6 +1,6 @@
# Shardformer
-Author: [Baizhou Zhang](https://github.com/Fridge003)
+Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer)
**Prerequisite**
- [并行技术](../concepts/paradigms_of_parallelism.md)
@@ -16,6 +16,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003)
- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)
- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
+- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198)
## Introduction
@@ -65,6 +66,19 @@ Shardformer的配置由类`ShardConfig`的参数控制:
and initialize your model with these imported classes.
+### Sequence Parallelism
+
+Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used only together with 1D tensor parallelism, to further reduce the memory occupation of activations during computation.
+
+1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, and $\vec{g}$ performs one All-Reduce in the forward pass to gather the outputs from all devices.
+
+2. When using sequence parallelism, $g$ needs to perform an All-Gather during the forward pass to gather the inputs along the sequence dimension, and a Reduce-Scatter during the backward pass to split the gradient. $\vec{g}$ needs to perform a Reduce-Scatter to split the output of the row linear layer across devices along the sequence dimension, and an All-Gather to gather the whole gradient.
+
+3. The NCCL implementation of All-Reduce adopts the `Ring All-Reduce` approach, which consists of one Reduce-Scatter and one All-Gather of equal cost. Therefore, combining sequence parallelism with tensor parallelism does not introduce additional communication overhead compared to tensor parallelism alone (see the sketch after this list).
+
+4. One thing to note is that when performing sequence parallelism with the `Column Linear` module of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. During the forward pass, only the portion of the input split along the sequence dimension is retained, with shape $(batch, sequence\_len/k, hidden\_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. In the implementation, however, this All-Gather can be overlapped with the gradient computation, which introduces no additional communication overhead (this corresponds to the `enable_sequence_overlap` parameter in `Shardformer`).
+
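+As a minimal sketch of point 3 (assuming the standard ring-collective cost model; illustrative, not code from this repository): on $k$ devices, a ring All-Reduce of $n$ bytes is exactly one Reduce-Scatter plus one All-Gather, so splitting the two halves between the forward and backward passes is cost-neutral:
+
+```python
+def reduce_scatter_bytes(n: int, k: int) -> float:
+    # Each device sends (k-1)/k of the buffer in a ring Reduce-Scatter.
+    return (k - 1) / k * n
+
+def all_gather_bytes(n: int, k: int) -> float:
+    # A ring All-Gather moves the same volume per device.
+    return (k - 1) / k * n
+
+def ring_all_reduce_bytes(n: int, k: int) -> float:
+    # Ring All-Reduce = one Reduce-Scatter followed by one All-Gather.
+    return reduce_scatter_bytes(n, k) + all_gather_bytes(n, k)
+
+n, k = 4 * 1024**2, 4  # e.g. a 4 MiB activation on a 4-way tensor-parallel group
+assert ring_all_reduce_bytes(n, k) == reduce_scatter_bytes(n, k) + all_gather_bytes(n, k)
+```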
+
## How Shardformer Works
Generally, Shardformer works through the following four kinds of *replacements*:
From 608cffaed3821bacdfce7c44cdf09e6cd38d32c2 Mon Sep 17 00:00:00 2001
From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Date: Fri, 15 Sep 2023 17:12:46 +0800
Subject: [PATCH 12/58] [example] add gpt2 HybridParallelPlugin example (#4653)
* add gpt2 HybridParallelPlugin example
* update readme and testci
* update test ci
* fix test_ci bug
* update requirements
* add requirements
* update requirements
* add requirement
* rename file
---
examples/language/gpt/README.md | 10 +
.../language/gpt/hybridparallelism/data.py | 127 ++++++++
.../gpt/hybridparallelism/finetune.py | 299 ++++++++++++++++++
.../language/gpt/hybridparallelism/run.sh | 5 +
examples/language/gpt/requirements.txt | 5 +
examples/language/gpt/test_ci.sh | 3 +
6 files changed, 449 insertions(+)
create mode 100644 examples/language/gpt/hybridparallelism/data.py
create mode 100644 examples/language/gpt/hybridparallelism/finetune.py
create mode 100644 examples/language/gpt/hybridparallelism/run.sh
diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md
index 47d24a4d69cb..03679e66404a 100644
--- a/examples/language/gpt/README.md
+++ b/examples/language/gpt/README.md
@@ -65,6 +65,16 @@ Titans provides a customized GPT model, which uses distributed operators as buil
In [./titans/README.md], we provide a hybrid parallelism of ZeRO, TP and PP.
You can switch parallel strategies using a config file.
+### Hybridparallelism
+
+Hybrid parallelism provides a user-friendly plugin to set multiple parallelism methods for training and inference. In [./hybridparallelism], we provide an example of fine-tuning gpt2 using hybrid parallelism.
+
+Quick run
+```bash
+cd ./hybridparallelism
+bash run.sh
+```
+
## Performance
Testbed: a cluster of 8xA100 (80GB) and 1xAMD EPYC 7543 32-Core Processor (512 GB). GPUs are connected via PCI-e.
diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py
new file mode 100644
index 000000000000..981cedcca8c2
--- /dev/null
+++ b/examples/language/gpt/hybridparallelism/data.py
@@ -0,0 +1,127 @@
+import datasets
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
+
+
+class GLUEDataBuilder:
+
+ task_text_field_map = {
+ "cola": ["sentence"],
+ "sst2": ["sentence"],
+ "mrpc": ["sentence1", "sentence2"],
+ "qqp": ["question1", "question2"],
+ "stsb": ["sentence1", "sentence2"],
+ "mnli": ["premise", "hypothesis"],
+ "qnli": ["question", "sentence"],
+ "rte": ["sentence1", "sentence2"],
+ "wnli": ["sentence1", "sentence2"],
+ "ax": ["premise", "hypothesis"],
+ }
+
+ glue_task_num_labels = {
+ "cola": 2,
+ "sst2": 2,
+ "mrpc": 2,
+ "qqp": 2,
+ "stsb": 1,
+ "mnli": 3,
+ "qnli": 2,
+ "rte": 2,
+ "wnli": 2,
+ "ax": 3,
+ }
+
+ loader_columns = [
+ "datasets_idx",
+ "input_ids",
+ "token_type_ids",
+ "attention_mask",
+ "start_positions",
+ "end_positions",
+ "labels",
+ ]
+
+ def __init__(
+ self,
+ model_name_or_path: str,
+ plugin: DPPluginBase,
+ task_name: str = "mrpc",
+ max_seq_length: int = 128,
+ train_batch_size: int = 32,
+ eval_batch_size: int = 32,
+ **kwargs,
+ ):
+ super().__init__()
+ self.model_name_or_path = model_name_or_path
+ self.task_name = task_name
+ self.max_seq_length = max_seq_length
+ self.train_batch_size = train_batch_size
+ self.eval_batch_size = eval_batch_size
+ self.plugin = plugin
+
+ self.text_fields = self.task_text_field_map[task_name]
+ self.num_labels = self.glue_task_num_labels[task_name]
+ self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+ self.setup()
+
+ def setup(self):
+ self.dataset = datasets.load_dataset("glue", self.task_name)
+
+ for split in self.dataset.keys():
+ self.dataset[split] = self.dataset[split].map(
+ self.convert_to_features,
+ batched=True,
+ remove_columns=["label"],
+ )
+ self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
+ self.dataset[split].set_format(type="torch", columns=self.columns)
+
+ self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
+
+ def prepare_data(self):
+ datasets.load_dataset("glue", self.task_name)
+ AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+
+ def train_dataloader(self):
+ return self.plugin.prepare_dataloader(self.dataset["train"],
+ batch_size=self.train_batch_size,
+ shuffle=True,
+ drop_last=True)
+
+ def val_dataloader(self):
+ if len(self.eval_splits) == 1:
+ return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+ elif len(self.eval_splits) > 1:
+ return [
+ self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+ for x in self.eval_splits
+ ]
+
+ def test_dataloader(self):
+ if len(self.eval_splits) == 1:
+ return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+ elif len(self.eval_splits) > 1:
+ return [
+ self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+ for x in self.eval_splits
+ ]
+
+ def convert_to_features(self, example_batch):
+
+ # Either encode single sentence or sentence pairs
+ if len(self.text_fields) > 1:
+ texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
+ else:
+ texts_or_text_pairs = example_batch[self.text_fields[0]]
+
+ # Tokenize the text/text pairs
+ features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
+ max_length=self.max_seq_length,
+ padding='max_length',
+ truncation=True)
+
+ # Rename label to labels to make it easier to pass to model forward
+ features["labels"] = example_batch["label"]
+
+ return features
diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py
new file mode 100644
index 000000000000..03e5ec91b3fe
--- /dev/null
+++ b/examples/language/gpt/hybridparallelism/finetune.py
@@ -0,0 +1,299 @@
+import argparse
+from contextlib import nullcontext
+from typing import Callable, List, Union
+
+import evaluate
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from data import GLUEDataBuilder
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+# ==============================
+# Prepare Hyperparameters
+# ==============================
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+
+output_transform_fn = lambda x: x
+criterion = lambda x: x.loss
+
+
+def move_to_cuda(batch):
+ return {k: v.cuda() for k, v in batch.items()}
+
+
+@torch.no_grad()
+def evaluate_model(
+ model: nn.Module,
+ criterion,
+ test_dataloader: Union[DataLoader, List[DataLoader]],
+ num_labels: int,
+ task_name: str,
+ eval_splits: List[str],
+ booster: Booster,
+ coordinator: DistCoordinator,
+):
+ metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
+ model.eval()
+
+ def evaluate_subset(dataloader: DataLoader):
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+
+ accum_loss = torch.zeros(1, device=get_current_device())
+ for batch in dataloader:
+ batch = move_to_cuda(batch)
+ labels = batch["labels"]
+ if use_pipeline:
+ pg_mesh = booster.plugin.pg_mesh
+ pp_group = booster.plugin.pp_group
+ current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
+ current_rank = dist.get_rank()
+ batch = iter([batch])
+ outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
+
+ if is_pp_last_stage:
+ logits = outputs["outputs"]["logits"]
+ val_loss = outputs["loss"]
+ accum_loss.add_(val_loss)
+
+ if num_labels > 1:
+ preds = torch.argmax(logits, axis=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
+ dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)
+
+ metric.add_batch(predictions=preds, references=labels)
+ elif current_rank in current_pp_group_ranks:
+ object_list = [None, None]
+ dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)
+
+ metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
+ accum_loss.add_(object_list[1].to(get_current_device()))
+
+ else:
+ batch = move_to_cuda(batch)
+ outputs = model(**batch)
+ val_loss, logits = outputs[:2]
+ accum_loss.add_(val_loss)
+
+ if num_labels > 1:
+ preds = torch.argmax(logits, axis=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
+ metric.add_batch(predictions=preds, references=labels)
+
+ results = metric.compute()
+ dist.all_reduce(accum_loss.div_(len(dataloader)))
+ if coordinator.is_master() and results is not None:
+ results['loss'] = accum_loss.item() / coordinator.world_size
+
+ return results
+
+ if isinstance(test_dataloader, DataLoader):
+ return evaluate_subset(test_dataloader)
+ else:
+ assert len(test_dataloader) == len(eval_splits)
+ final_results = {}
+ for split, sub_loader in zip(eval_splits, test_dataloader):
+ results = evaluate_subset(sub_loader)
+ final_results.update({f'{k}_{split}': v for k, v in results.items()})
+ return final_results
+
+
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+ train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ total_step = len(train_dataloader)
+
+ model.train()
+ optimizer.zero_grad()
+ train_dataloader_iter = iter(train_dataloader)
+ with tqdm(range(total_step),
+ desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+ disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
+ # Forward pass
+ for _ in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(train_dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+ else:
+ data = next(train_dataloader_iter)
+ data = move_to_cuda(data)
+ outputs = model(**data)
+ loss = _criterion(outputs, None)
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({'loss': loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+
+def main():
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
+ parser.add_argument('-p',
+ '--plugin',
+ type=str,
+ default='torch_ddp',
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
+ help="plugin to use")
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ default="gpt2",
+ help="only gpt2 now",
+ )
+ parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
+ parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
+ args = parser.parse_args()
+
+ if args.model_type == 'gpt2':
+ model_name = "gpt2"
+ else:
+ raise RuntimeError
+ # ==============================
+ # Launch Distributed Environment
+ # ==============================
+ colossalai.launch_from_torch(config={}, seed=42)
+ coordinator = DistCoordinator()
+
+ # local_batch_size = BATCH_SIZE // coordinator.world_size
+ lr = LEARNING_RATE * coordinator.world_size
+
+ # ==============================
+ # Instantiate Plugin and Booster
+ # ==============================
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ elif args.plugin == 'hybrid_parallel':
+
+ # modify the param accordingly for finetuning test cases
+ plugin = HybridParallelPlugin(tp_size=1,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ zero_stage=1,
+ precision='fp16',
+ initial_scale=1)
+
+ booster = Booster(plugin=plugin, **booster_kwargs)
+
+ # ==============================
+ # Prepare Dataloader
+ # ==============================
+ data_builder = GLUEDataBuilder(model_name,
+ plugin,
+ args.task,
+ train_batch_size=BATCH_SIZE,
+ eval_batch_size=BATCH_SIZE)
+ train_dataloader = data_builder.train_dataloader()
+ test_dataloader = data_builder.test_dataloader()
+
+ # ====================================
+ # Prepare model, optimizer
+ # ====================================
+ # gpt2 pretrained model
+
+ cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
+
+ if model_name == "gpt2":
+ model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+ else:
+ raise RuntimeError
+
+ # optimizer
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": WEIGHT_DECAY,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
+
+ # lr scheduler
+ total_steps = len(train_dataloader) * NUM_EPOCHS
+ num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_steps,
+ )
+
+ def _criterion(outputs, inputs):
+ outputs = output_transform_fn(outputs)
+ loss = criterion(outputs)
+ return loss
+
+ # ==============================
+ # Boost with ColossalAI
+ # ==============================
+ model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+ optimizer,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
+
+ # ==============================
+ # Train model
+ # ==============================
+ for epoch in range(NUM_EPOCHS):
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
+
+ results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task,
+ data_builder.eval_splits, booster, coordinator)
+
+ if coordinator.is_master():
+ print(results)
+ if args.target_f1 is not None and 'f1' in results:
+ assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/gpt/hybridparallelism/run.sh b/examples/language/gpt/hybridparallelism/run.sh
new file mode 100644
index 000000000000..679cbbf9b1e2
--- /dev/null
+++ b/examples/language/gpt/hybridparallelism/run.sh
@@ -0,0 +1,5 @@
+# load via internet
+torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2"
+
+# load from local
+# torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2" --pretrained_path "your/path/to/pretrained_model"
diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt
index ef58bb76bfc8..1a173f228aee 100644
--- a/examples/language/gpt/requirements.txt
+++ b/examples/language/gpt/requirements.txt
@@ -1,2 +1,7 @@
transformers >= 4.23
colossalai
+evaluate
+tqdm
+scipy
+scikit-learn
+numpy
diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh
index d67c17229e71..b9e4e43a8d35 100644
--- a/examples/language/gpt/test_ci.sh
+++ b/examples/language/gpt/test_ci.sh
@@ -1,2 +1,5 @@
set -x
+pip install -r requirements.txt
+
cd gemini && bash test_ci.sh
+cd ../hybridparallelism && bash run.sh
From 451c3465fbde69695270bfd8f7ad26bebc079432 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 15 Sep 2023 17:39:10 +0800
Subject: [PATCH 13/58] [doc] polish shardformer doc (#4735)
* arrange position of chapters
* fix typos in seq parallel doc
---
docs/source/en/features/shardformer.md | 167 ++++++++++----------
docs/source/zh-Hans/features/shardformer.md | 141 ++++++++---------
2 files changed, 153 insertions(+), 155 deletions(-)
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
index ca23f07421d1..4abfff8a3cfa 100644
--- a/docs/source/en/features/shardformer.md
+++ b/docs/source/en/features/shardformer.md
@@ -29,90 +29,6 @@ This module aims to make parallelization hassle-free for users who are not from
Within a few lines of codes, users can turn a model into a state ready for distributed training.
Also, Shardformer contains various optimization tools for acceleration and memory saving during forward/backward pass.
-## Usage
-
-### Shardformer Configuration
-
-The configuration of Shardformer is controlled by class `ShardConfig`:
-
-{{ autodoc:colossalai.shardformer.ShardConfig }}
-
-If you want to enable Apex Fused Layernorm, please install `apex`.
-If you want to enable the usage of flash attention, please install `flash_attn`.
-In addition, xFormers's `cutlass_op` can serve as a backup for flash attention.
-
-### Enabling Shardformer
-
-#### 1. Enabling Shardformer Through Booster (Recommended)
-
-Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer.
-The main reason is that pipeline parallelism cannot work without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.
-
-More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).
-
-[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline.
-
-
-#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)
-
-You can also use Shardformer by manually calling the Shardformer APIs. However, this usage is not recommended, since pipeline parallelism can't run without `Booster`.
-
-[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
-is an example of how to trigger `Shardformer` by calling the Shardformer APIs.
-
-
-### Precautions
-
-1. When enabling pipeline parallelism, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), which will cause unexpected errors. Rather, please do the forward/backward pass by calling the `booster.execute_pipeline` method.
-
-2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer can't process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
-
-3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
- ```python
- from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
- from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
- ```
- when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
-
-### Sequence Parallelism
-
-Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used only together with 1D tensor parallelism to further reduce the memory occupation of activations in computations.
-
-1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, and $\vec{g}$ performs one All-Reduce in the forward pass to gather the whole outputs from all devices.
-
-2. When using sequence parallelism, $\vec{g}$ needs to do an All-Gather to gather the inputs along the sequence dimension during forward, and a Reduce-Scatter to split the gradient during backward. $\vec{g}$ needs to do a Reduce-Scatter to split the output of the row linear layer of tensor parallelism across all devices along the sequence dimension, and an All-Gather to get the whole gradient during backward.
-
-3. The implementation of All-Reduce using NCCL adopts the `Ring All-Reduce` approach, which consists of a Reduce-Scatter operation and an All-Gather operation with equal costs. Therefore, compared to sequence parallelism plus tensor parallelism, it does not introduce additional communication overhead.
-
-4. One important thing to note is that when using sequence parallelism with the `Column Linear` module of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. During the forward pass, only the portion of the input that is split along the sequence dimension is retained, with shape $(batch, sequence\_len/k, hidden\_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. However, in the implementation, it is possible to overlap the gradient computation with the All-Gather communication operation, which does not introduce additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`).
-
-## How Shardformer Works
-
-Generally, Shardformer works through the following four kinds of *replacements*:
-
-1. Replacing the original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
-The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters.
-Also, new `forward` methods will replace the original ones so as to execute distributed computation, such as the linear layers' split/gather operations executed under tensor parallelism.
-Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
-
-2. Replacing attributes of the original Huggingface Transformers layers with appropriate attributes for distributed training.
-For example, when training LlaMa-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
-
-3. Replacing the `forward` methods implemented by the original Huggingface
-Transformers library with our customized `forward` methods.
-This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between different pipeline stages.
-Also, optimization methods such as flash attention or sequence parallelism can be injected into the `forward` process through our customized `forward` method.
-
-4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by the current device (this is why it's called Shardformer).
-By executing the `ModelSharder.shard` method, the current device will only keep the part of the model parameters it's supposed to take care of.
-To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to the current pipeline stage when using pipeline parallelism, or both of them.
-All other parameters are released so as to free up memory.
-As a result, the optimizer will only compute the states corresponding to these retained parameters, further reducing memory usage.
-
-All of these replacements are implemented with manually written policies and forward functions.
-If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
-
-
## Supporting Information
Model/Feature Compatibility Matrix:
@@ -279,4 +195,87 @@ List of model families we plan to support in the near future:
The support matrix will grow larger as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project.
+## Usage
+
+### Shardformer Configuration
+
+The configuration of Shardformer is controlled by the class `ShardConfig`:
+
+{{ autodoc:colossalai.shardformer.ShardConfig }}
+
+If you want to enable Apex Fused Layernorm, please install `apex`.
+If you want to enable flash attention, please install `flash_attn`.
+In addition, xFormers' `cutlass_op` can serve as a backup for flash attention.
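+
+Below is a minimal sketch of building a `ShardConfig` (the flags follow the fields listed in the autodoc above; treat the exact combination as illustrative rather than prescriptive):
+
+```python
+from colossalai.shardformer import ShardConfig
+
+# Enable 1D tensor parallelism together with fused layernorm and flash attention.
+shard_config = ShardConfig(enable_tensor_parallelism=True,
+                           enable_fused_normalization=True,
+                           enable_flash_attention=True)
+```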
+
+### Enabling Shardformer
+
+#### 1. Enabling Shardformer Through Booster (Recommended)
+
+Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the recommended way to unlock the power of Shardformer.
+The main reason is that pipeline parallelism cannot work without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.
+
+More details about this usage can be found in the chapters [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).
+
+[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that the forward/backward pass is performed differently depending on whether pipeline parallelism is used.
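+
+A condensed sketch of this workflow is given below; it assumes `model`, `optimizer`, `criterion`, `dataloader` and `lr_scheduler` are already created, and the plugin arguments are illustrative (see the linked bert example for a complete script):
+
+```python
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+
+colossalai.launch_from_torch(config={})
+# HybridParallelPlugin builds a ShardConfig internally from tp_size/pp_size.
+plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
+booster = Booster(plugin=plugin)
+model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
+    model, optimizer, criterion=criterion, dataloader=dataloader, lr_scheduler=lr_scheduler)
+```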
+
+
+#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)
+
+You can also use Shardformer by manually calling the Shardformer APIs. However, this usage is not recommended, since pipeline parallelism can't run without `Booster`.
+
+[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
+is an example of how to trigger `Shardformer` by calling the Shardformer APIs.
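+
+A minimal sketch of the manual usage, assuming `model` is an already-constructed transformers model (mirroring the linked convergence benchmark):
+
+```python
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+shard_config = ShardConfig(enable_tensor_parallelism=True)
+shard_former = ShardFormer(shard_config=shard_config)
+# optimize() returns the sharded model together with its shared parameters.
+sharded_model, shared_params = shard_former.optimize(model)
+```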
+
+### Precautions
+
+1. When enabling pipeline parallelism, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), which will cause unexpected errors. Rather, please do the forward/backward pass by calling the `booster.execute_pipeline` method.
+
+2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer can't process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
+
+3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
+ ```python
+ from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+ ```
+ when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
+
+## How Shardformer Works
+
+Generally, Shardformer works through the following four kinds of *replacements*:
+
+1. Replacing the original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module (a minimal sketch of such a conversion is shown at the end of this section).
+The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters.
+Also, new `forward` methods will replace the original ones so as to execute distributed computation, such as the linear layers' split/gather operations executed under tensor parallelism.
+Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
+
+2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training.
+For example, when training LlaMa-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
+
+3. Replacing the `forward` methods implemented by the original Huggingface
+Transformers library with our customized `forward` methods.
+This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between different pipeline stages.
+Also, optimization methods such as flash attention or sequence parallelism can be injected into the `forward` process through our customized `forward` method.
+
+4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by the current device (this is why it's called Shardformer).
+By executing the `ModelSharder.shard` method, the current device will only keep the part of the model parameters it's supposed to take care of.
+To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to the current pipeline stage when using pipeline parallelism, or both of them.
+All other parameters are released so as to free up memory.
+As a result, the optimizer will only compute the states corresponding to these retained parameters, further reducing memory usage.
+
+All of these replacements are implemented with manually written policies and forward functions.
+If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
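+
+As an illustration of the first kind of replacement, the sketch below (with an assumed tensor-parallel process group) converts a plain linear layer into Shardformer's column-sharded counterpart:
+
+```python
+import torch.distributed as dist
+import torch.nn as nn
+from colossalai.shardformer.layer import Linear1D_Col
+
+tp_group = dist.group.WORLD    # assume all ranks form the tensor parallel group
+linear = nn.Linear(1024, 1024)
+# After conversion, each rank holds only its column shard of the weight.
+parallel_linear = Linear1D_Col.from_native_module(linear, process_group=tp_group)
+```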
+
+### Sequence Parallelism
+
+Sequence parallelism is a special optimization method supported by `Shardformer`. Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used only along with 1D tensor parallelism to further reduce the memory occupation of activation tensors during computation.
+
+1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, and $\vec{g}$ performs one All-Reduce in the forward pass to gather the whole outputs from all devices.
+
+2. When using sequence parallelism, $\vec{g}$ needs to do an All-Gather to gather the inputs along the sequence dimension during forward, and a Reduce-Scatter to split the gradient during backward. $\vec{g}$ needs to do a Reduce-Scatter to split the output of the `Row Linear` layer of tensor parallelism across all devices along the sequence dimension, and an All-Gather to get the whole gradient during backward.
+
+3. NCCL's implementation of All-Reduce adopts the `Ring All-Reduce` approach, which consists of a Reduce-Scatter operation and an All-Gather operation with equal costs. Therefore, compared with sequence parallelism and tensor parallelism, it does not introduce additional communication overhead.
+
+4. One important thing to note is that when using sequence parallelism along with the `Column Linear` module of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. During the forward pass, only the portion of the input that is split along the sequence dimension is retained, in the shape of $(batch, sequence\_len/k, hidden\_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. However, it is possible to overlap the gradient computation with the All-Gather communication operation in our implementation, which does not introduce additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`).
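+
+In practice, these behaviors are toggled through `ShardConfig`; a minimal sketch using the flag names documented above:
+
+```python
+from colossalai.shardformer import ShardConfig
+
+shard_config = ShardConfig(enable_tensor_parallelism=True,
+                           enable_sequence_parallelism=True,
+                           # Overlap the extra All-Gather with gradient computation.
+                           enable_sequence_overlap=True)
+```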
+
+
diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md
index 7de0c41c10d7..fe0e7a63ba44 100644
--- a/docs/source/zh-Hans/features/shardformer.md
+++ b/docs/source/zh-Hans/features/shardformer.md
@@ -25,77 +25,6 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
Motivated by this, the ColossalAI team developed **Shardformer**, a feature that automatically wraps mainstream Transformer models from HuggingFace for tensor parallel and pipeline parallel training strategies. In this way, users with little knowledge of distributed systems can easily train transformers models in parallel: with just a few lines of code, a model can be turned into a state ready for parallel training. In addition, Shardformer includes a variety of optimization tools for acceleration and memory saving during the forward/backward pass.
-## Usage
-
-### Shardformer Configuration
-
-The configuration of Shardformer is controlled by the parameters of the class `ShardConfig`:
-
-{{ autodoc:colossalai.shardformer.ShardConfig }}
-
-If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable flash attention, please install `flash_attn`. In addition, xFormers' `cutlass_op` can serve as a complementary optimization to flash attention.
-
-### Enabling Shardformer
-
-#### 1. Enabling Shardformer Through Booster (Recommended)
-
-Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the recommended usage. The main reason is that pipeline parallelism cannot work properly without calling the `execute_pipeline` method of `Booster`. In addition, `HybridParallelPlugin` provides the ability to combine the features of `Shardformer` with other features such as mixed precision training or Zero.
-
-More details about this usage can be found in the [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md) documentation. [Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of enabling `Shardformer` through `HybridParallelPlugin`.
-
-
-#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)
-
-You can also enable Shardformer by manually calling the Shardformer APIs. However, we do not recommend this usage, since pipeline parallelism cannot run properly without `Booster`.
-
-[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
-is an example of enabling `Shardformer` by calling the Shardformer APIs.
-
-
-### Precautions
-
-1. When pipeline parallelism is enabled, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), as this will cause unknown errors. In this case, please do the forward/backward pass by calling the `booster.execute_pipeline` method.
-
-2. When using Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please make sure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer cannot process the classifier layer correctly. A simple fix is to add dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
-
-3. The case of training ChatGLM-2 6B is a little special: since Huggingface Transformers does not officially support ChatGLM at present, please import the config/model classes as follows when training ChatGLM-2 with Shardformer:
-   ```python
-   from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-   from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
-   ```
-   and initialize the model with these imported classes.
-
-
-### Sequence Parallelism
-
-Sequence parallelism in `Shardformer` is slightly different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used only together with 1D tensor parallelism to further reduce the memory occupation of activations during computation.
-
-1. In normal [1D tensor parallelism](https://colossalai.org/docs/features/1D_tensor_parallel), there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, while $\vec{g}$ performs one All-Reduce in the forward pass to gather the outputs from all devices.
-
-2. When sequence parallelism is used, $\vec{g}$ needs to perform an All-Gather in the forward pass to gather the inputs along the sequence dimension, and a Reduce-Scatter in the backward pass to split the gradients. $\vec{g}$ needs to perform a Reduce-Scatter to split the output of the row linear layer across all devices along the sequence dimension, and an All-Gather to obtain the complete gradients.
-
-3. NCCL's All-Reduce implementation adopts the `Ring All-Reduce` approach, which consists of one Reduce-Scatter operation and one All-Gather operation of equal cost. Therefore, compared with sequence parallelism plus tensor parallelism, it does not introduce additional communication overhead.
-
-4. One thing to note is that when sequence parallelism is performed in the `Column Linear` layer of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. In the forward pass, only the part of the input split along the sequence dimension is retained, with a tensor shape such as $(batch, sequence\_len/k, hidden\_states)$. Therefore, an additional All-Gather operation is needed to obtain the complete input for gradient computation. However, in the implementation, this gradient computation can be overlapped with the All-Gather communication, which introduces no additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`).
-
-
-## How Shardformer Works
-
-Generally speaking, Shardformer works through the following four kinds of "replacements":
-
-1. Replacing the original PyTorch modules (e.g. `nn.Linear`, `nn.Embedding`) with our crafted distributed modules.
-A distributed module keeps the same attributes as the original module, but replaces the original parameters with new ones. A new forward function replaces the original one to execute distributed computation, such as the split/gather operations of linear layers under tensor parallelism. Each distributed module should implement its `from_native_module` static method to convert a PyTorch module to its corresponding distributed module.
-
-2. Replacing the attributes of the original Huggingface Transformers layers with attributes suitable for parallel training. For example, when training LlaMa-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
-
-3. Replacing the forward functions implemented by the original Huggingface transformers library with our customized forward functions. This replacement is essential for pipeline parallelism, which requires special forward functions to pass intermediate hidden states between different pipeline stages. In addition, optimization methods such as `flash attention` or sequence parallelism can be injected into the forward process through our customized forward functions.
-
-4. Replacing the complete model parameters and optimizer states with partial ones controlled only by the current device. By executing the `ModelSharder.shard` method, the current device keeps only the part of the model parameters it is supposed to handle. Specifically, these can be the parameter shards assigned to the current machine when using tensor parallelism, or the model parameters of the current pipeline stage when using pipeline parallelism, or both. All other parameters are released to save memory.
-In this way, the optimizer only computes the states corresponding to the retained parameters, further saving memory usage.
-
-All these replacements are implemented with manually written policies and forward functions. If you want to delve deeper into the design of Shardformer, or customize your own Shardformer policies, please refer to the [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
-
-
## Supporting Information
Model/Feature Compatibility Matrix:
@@ -262,4 +191,74 @@ The configuration of Shardformer is controlled by the parameters of the class `ShardConfig`:
As more models and optimization tools emerge in the future, the list of models/optimization tools we support will keep growing. If you have any suggestions about the models/optimization tools we should support, feel free to join the discussion in the [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of the project.
+## Usage
+
+### Shardformer Configuration
+
+The configuration of Shardformer is controlled by the parameters of the class `ShardConfig`:
+
+{{ autodoc:colossalai.shardformer.ShardConfig }}
+
+If you want to enable Apex Fused Layernorm, please install `apex`. If you want to enable flash attention, please install `flash_attn`. In addition, xFormers' `cutlass_op` can serve as a complementary optimization to flash attention.
+
+### Enabling Shardformer
+
+#### 1. Enabling Shardformer Through Booster (Recommended)
+
+Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the recommended usage. The main reason is that pipeline parallelism cannot work properly without calling the `execute_pipeline` method of `Booster`. In addition, `HybridParallelPlugin` provides the ability to combine the features of `Shardformer` with other features such as mixed precision training or Zero.
+
+More details about this usage can be found in the [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md) documentation. [Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of enabling `Shardformer` through `HybridParallelPlugin`.
+
+
+#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)
+
+You can also enable Shardformer by manually calling the Shardformer APIs. However, we do not recommend this usage, since pipeline parallelism cannot run properly without `Booster`.
+
+[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
+is an example of enabling `Shardformer` by calling the Shardformer APIs.
+
+
+### Precautions
+
+1. When pipeline parallelism is enabled, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), as this will cause unknown errors. In this case, please do the forward/backward pass by calling the `booster.execute_pipeline` method.
+
+2. When using Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please make sure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer cannot process the classifier layer correctly. A simple fix is to add dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
+
+3. The case of training ChatGLM-2 6B is a little special: since Huggingface Transformers does not officially support ChatGLM at present, please import the config/model classes as follows when training ChatGLM-2 with Shardformer:
+   ```python
+   from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+   from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+   ```
+   and initialize the model with these imported classes.
+
+
+## How Shardformer Works
+
+Generally speaking, Shardformer works through the following four kinds of "replacements":
+
+1. Replacing the original PyTorch modules (e.g. `nn.Linear`, `nn.Embedding`) with our crafted distributed modules.
+A distributed module keeps the same attributes as the original module, but replaces the original parameters with new ones. A new forward function replaces the original one to execute distributed computation, such as the split/gather operations of linear layers under tensor parallelism. Each distributed module should implement its `from_native_module` static method to convert a PyTorch module to its corresponding distributed module.
+
+2. Replacing the attributes of the original Huggingface Transformers layers with attributes suitable for parallel training. For example, when training LlaMa-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
+
+3. Replacing the forward functions implemented by the original Huggingface transformers library with our customized forward functions. This replacement is essential for pipeline parallelism, which requires special forward functions to pass intermediate hidden states between different pipeline stages. In addition, optimization methods such as `flash attention` or sequence parallelism can be injected into the forward process through our customized forward functions.
+
+4. Replacing the complete model parameters and optimizer states with partial ones controlled only by the current device. By executing the `ModelSharder.shard` method, the current device keeps only the part of the model parameters it is supposed to handle. Specifically, these can be the parameter shards assigned to the current machine when using tensor parallelism, or the model parameters of the current pipeline stage when using pipeline parallelism, or both. All other parameters are released to save memory.
+In this way, the optimizer only computes the states corresponding to the retained parameters, further saving memory usage.
+
+All these replacements are implemented with manually written policies and forward functions. If you want to delve deeper into the design of Shardformer, or customize your own Shardformer policies, please refer to the [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
+
+### Sequence Parallelism
+
+Sequence parallelism is a special optimization method supported by `Shardformer`. Sequence parallelism in `Shardformer` is slightly different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used only together with 1D tensor parallelism to further reduce the memory occupation of activations during computation.
+
+1. In normal [1D tensor parallelism](https://colossalai.org/docs/features/1D_tensor_parallel), there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, while $\vec{g}$ performs one All-Reduce in the forward pass to gather the outputs from all devices.
+
+2. When sequence parallelism is used, $\vec{g}$ needs to perform an All-Gather in the forward pass to gather the inputs along the sequence dimension, and a Reduce-Scatter in the backward pass to split the gradients. $\vec{g}$ needs to perform a Reduce-Scatter to split the output of the row linear layer across all devices along the sequence dimension, and an All-Gather to obtain the complete gradients.
+
+3. NCCL's All-Reduce implementation adopts the `Ring All-Reduce` approach, which consists of one Reduce-Scatter operation and one All-Gather operation of equal cost. Therefore, compared with sequence parallelism plus tensor parallelism, it does not introduce additional communication overhead.
+
+4. One thing to note is that when sequence parallelism is performed in the `Column Linear` layer of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. In the forward pass, only the part of the input split along the sequence dimension is retained, with a tensor shape such as $(batch, sequence\_len/k, hidden\_states)$. Therefore, an additional All-Gather operation is needed to obtain the complete input for gradient computation. However, in the implementation, this gradient computation can be overlapped with the All-Gather communication, which introduces no additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`).
+
+
From ac2797996b362e5bded4d0eec18ef96efc12b086 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Date: Fri, 15 Sep 2023 17:53:13 +0800
Subject: [PATCH 14/58] [shardformer] add custom policy in hybrid parallel
plugin (#4718)
* add custom policy
* update assert
---
.../booster/plugin/hybrid_parallel_plugin.py | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 3fbeebcc4110..d15245523226 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -22,6 +22,7 @@
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
@@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class HybridParallelModule(ModelWrapper):
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
- ddp_config: dict) -> None:
+ ddp_config: dict, custom_policy: Policy) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
shardformer = ShardFormer(shard_config)
- module, self.shared_params = shardformer.optimize(module)
+ if custom_policy is not None:
+            assert isinstance(custom_policy, Policy), 'custom_policy must be an instance of Policy'
+ module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
# setting process groups for shared parameters
self.shared_param_process_groups = []
@@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+ custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
"""
def __init__(self,
@@ -302,7 +306,8 @@ def __init__(self,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True) -> None:
+ overlap_communication: bool = True,
+ custom_policy: Policy = None) -> None:
super().__init__()
assert dist.get_world_size() % (
@@ -326,6 +331,7 @@ def __init__(self,
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
+ self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
@@ -405,7 +411,7 @@ def configure(
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
- self.ddp_config)
+ self.ddp_config, self.custom_policy)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
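With this change, a user-defined Shardformer policy can be routed through the plugin. A minimal sketch of the intended call site (the policy subclass here is hypothetical; a real one must implement the abstract methods of `Policy`):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.policies.base_policy import Policy

class MyModelPolicy(Policy):    # hypothetical user-defined policy
    ...

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, custom_policy=MyModelPolicy())
booster = Booster(plugin=plugin)
```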
From 4c4482f3adb56943a150b8b7ed886e2218fc98d5 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 15 Sep 2023 18:45:44 +0800
Subject: [PATCH 15/58] [example] llama2 add fine-tune example (#4673)
* [shardformer] update shardformer readme
[shardformer] update shardformer readme
[shardformer] update shardformer readme
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] change dataset
* [shardformer] change dataset
* [shardformer] fix CI
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
[example] update opt example
[example] resolve comments
fix
fix
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* [example] llama2 add finetune example
* fix
* update llama2 example
* update llama2 example
* fix
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* update llama2 example
* Update requirements.txt
* update llama2 example
* update llama2 example
* update llama2 example
---
.../hybrid_parallel_checkpoint_io.py | 4 +-
examples/language/bert/finetune.py | 7 +-
examples/language/llama2/README.md | 39 ++-
examples/language/llama2/finetune.py | 295 ++++++++++++++++++
examples/language/llama2/pretrain.py | 79 +++--
examples/language/llama2/requirements.txt | 2 +-
examples/language/opt/README.md | 7 +-
examples/language/opt/requirements.txt | 4 +-
8 files changed, 402 insertions(+), 35 deletions(-)
create mode 100644 examples/language/llama2/finetune.py
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 6eee3ace0308..270fd8564754 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -13,6 +13,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from colossalai.cluster import DistCoordinator
from colossalai.interface import OptimizerWrapper
from .general_checkpoint_io import GeneralCheckpointIO
@@ -71,6 +72,7 @@ def __init__(self,
self.verbose = verbose
self.working_to_master_map = None
self.master_to_working_map = None
+ self.coordinator = DistCoordinator()
@staticmethod
def _model_sharder(model: nn.Module,
@@ -655,7 +657,7 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor,
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
- state_[k] = v.detach().clone().cpu()
+ state_[k] = v.detach().clone().cpu()
return state_
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index 2e8780806f19..fb6e4332c2f9 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -129,14 +129,13 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
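+    # With pipeline parallelism the loss is only produced on the last stage, so that rank
+    # should print; otherwise the global master rank prints.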
+ print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
total_step = len(train_dataloader)
model.train()
optimizer.zero_grad()
train_dataloader_iter = iter(train_dataloader)
- with tqdm(range(total_step),
- desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
- disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
+ with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not print_flag) as pbar:
# Forward pass
for _ in pbar:
if use_pipeline:
@@ -192,13 +191,13 @@ def main():
model_name = "albert-xxlarge-v2"
else:
raise RuntimeError
+
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={}, seed=42)
coordinator = DistCoordinator()
- # local_batch_size = BATCH_SIZE // coordinator.world_size
lr = LEARNING_RATE * coordinator.world_size
# ==============================
diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md
index c8fc86d29d97..83ef99b57d42 100644
--- a/examples/language/llama2/README.md
+++ b/examples/language/llama2/README.md
@@ -92,7 +92,7 @@ Make sure master node can access all nodes (including itself) by ssh without pas
Here is details about CLI arguments:
- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2.
-- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
+- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It supports any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
@@ -195,3 +195,40 @@ If you run the above command successfully, you will get the following results:
year={2023}
}
```
+
+
+# Fine-tune Llama2
+
+We also provide an example of fine-tuning Llama2 in `finetune.py`.
+
+Make sure the master node can access all nodes (including itself) via passwordless SSH.
+
+Here are details about the CLI arguments:
+
+- Pretrained checkpoint path: `--model_path`. The path to your model checkpoint; it can be a local directory or a Hugging Face model tag.
+- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
+- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
+- Task name: `--task_name`. The task to fine-tune on; it also determines which subset of the dataset is loaded. The default value is `super_natural_instructions`.
+- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
+- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
+- Learning rate: `--lr`. The default value is 3e-4.
+- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
+- Gradient checkpointing: `-g`, `--grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. It is recommended to enable this option when training with a large batch size.
+- Max length: `-l`, `--max_length`. The default value is 4096.
+- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
+- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
+- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
+- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
+- Gradient clipping: `--grad_clip`. The default value is 1.0.
+- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
+- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This helps accelerate training while saving memory. We recommend always enabling flash attention.
+
+
+```shell
+torchrun --standalone --nproc_per_node 8 finetune.py \
+ --plugin "hybrid_parallel" \
+ --dataset "yizhongw/self_instruct" \
+ --model_path "/path/llama" \
+ --task_name "super_natural_instructions" \
+ --save_dir "/path/output"
+```
diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py
new file mode 100644
index 000000000000..0efbf193c9a9
--- /dev/null
+++ b/examples/language/llama2/finetune.py
@@ -0,0 +1,295 @@
+import argparse
+import math
+import os
+import resource
+from contextlib import nullcontext
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from attn import SUPPORT_XFORMERS, replace_xformers
+from data_utils import load_json, prepare_dataloader, save_json
+from datasets import load_dataset
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+
+def get_model_numel(model: nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def format_numel_str(numel: int) -> str:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ if numel >= B:
+ return f'{numel / B:.2f} B'
+ elif numel >= M:
+ return f'{numel / M:.2f} M'
+ elif numel >= K:
+ return f'{numel / K:.2f} K'
+ else:
+ return f'{numel}'
+
+
+def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
+ texts = [sample['prompt'] + sample['completion'] for sample in batch]
+ data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
+ data = {k: v.cuda() for k, v in data.items()}
+ data['labels'] = data['input_ids'].clone()
+ return data
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+ tensor.div_(dist.get_world_size())
+ return tensor
+
+
+def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
+ batch_size: int, coordinator: DistCoordinator, save_dir: str):
+ save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
+ os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
+
+ booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
+ booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
+ running_states = {
+ 'epoch': epoch,
+ 'step': step,
+ 'sample_start_index': step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, 'running_states.json'))
+
+
+def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
+ load_dir: str) -> Tuple[int, int, int]:
+ booster.load_model(model, os.path.join(load_dir, 'model'))
+ booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
+ booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
+ running_states = load_json(os.path.join(load_dir, 'running_states.json'))
+ return running_states['epoch'], running_states['step'], running_states['sample_start_index']
+
+
+def _criterion(outputs, inputs):
+ return outputs.loss
+
+
+def main():
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_path', type=str, help="pretrained checkpoint path, used with mode==finetune")
+ parser.add_argument('-p',
+ '--plugin',
+ choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
+ default='gemini',
+ help='Choose which plugin to use')
+ parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path')
+ parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run')
+ parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
+ parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
+ parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
+    parser.add_argument('-w', '--weight_decay', type=float, default=0.1, help='Weight decay')
+ parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
+ parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
+ parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
+ parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
+ parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
+ parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
+ parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
+ parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
+ parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
+ args = parser.parse_args()
+
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == 'gemini':
+ plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
+ elif args.plugin == 'gemini_auto':
+ plugin = GeminiPlugin(precision=args.mixed_precision,
+ placement_policy='auto',
+ initial_scale=2**16,
+ max_norm=args.grad_clip)
+ elif args.plugin == 'zero2':
+ plugin = LowLevelZeroPlugin(stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip)
+ elif args.plugin == 'zero2_cpu':
+ plugin = LowLevelZeroPlugin(stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip)
+ elif args.plugin == 'hybrid_parallel':
+ # modify the param accordingly, default configuration is for llama2-7b
+ plugin = HybridParallelPlugin(tp_size=4,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_jit_fused=False,
+ zero_stage=0,
+ precision='fp32',
+ initial_scale=1)
+ else:
+ raise ValueError(f'Unknown plugin {args.plugin}')
+
+ booster = Booster(plugin=plugin)
+
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+
+ # ==============================
+ # Initialize Tensorboard
+ # ==============================
+ if print_flag:
+ os.makedirs(args.tensorboard_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+
+ # ==============================
+ # Initialize Model, Optimizer and LR Scheduler
+ # ==============================
+
+ config = LlamaConfig.from_pretrained(args.model_path)
+ # use lazy init when using GeminiPlugin
+ init_ctx = LazyInitContext(
+ default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
+
+ with init_ctx:
+ model = LlamaForCausalLM(config)
+
+ # ==============================
+ # Initialize Tokenizer, Dataset and Dataloader
+ # ==============================
+ tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
+ # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
+ tokenizer.pad_token = tokenizer.unk_token
+
+ dataset = load_dataset(args.dataset, args.task_name)
+ train_ds = dataset['train']
+ dataloader = prepare_dataloader(train_ds,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=partial(tokenize_batch_for_finetune,
+ tokenizer=tokenizer,
+ max_length=args.max_length))
+
+ if args.grad_checkpoint:
+ model.gradient_checkpointing_enable()
+ if args.flash_attention:
+        assert SUPPORT_XFORMERS, 'Flash attention requires xformers to be installed'
+ replace_xformers(model)
+
+ model_numel = get_model_numel(model)
+ coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
+
+    optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
+ total_step = args.num_epochs * len(dataloader)
+ lr_scheduler = CosineAnnealingWarmupLR(optimizer,
+ total_steps=total_step,
+ warmup_steps=math.ceil(total_step * 0.03),
+ eta_min=0.1 * args.lr)
+ default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
+ optimizer,
+ dataloader=dataloader,
+ lr_scheduler=lr_scheduler)
+ torch.set_default_dtype(torch.float)
+
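+    # Load the pretrained weights into the boosted (possibly sharded) model.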
+ booster.load_model(model, args.model_path)
+
+ coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
+ coordinator.print_on_master(
+ f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
+
+ # load checkpoint if specified
+ start_epoch = 0
+ start_step = 0
+ sampler_start_idx = 0
+ if args.load is not None:
+ coordinator.print_on_master('Loading checkpoint')
+ start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
+ coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
+
+ num_steps_per_epoch = len(dataloader)
+
+ # if resume training, set the sampler start index to the correct value
+ dataloader.sampler.set_start_index(sampler_start_idx)
+ for epoch in range(start_epoch, args.num_epochs):
+ dataloader.sampler.set_epoch(epoch)
+ step_nums = num_steps_per_epoch - start_step
+ dataloader_iter = iter(dataloader)
+
+ with tqdm(range(step_nums),
+ desc=f'Epoch {epoch}',
+ disable=not print_flag,
+ total=num_steps_per_epoch,
+ initial=start_step) as pbar:
+ for step in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ loss = outputs["loss"]
+ else:
+ batch = next(dataloader_iter)
+ outputs = model(**batch)
+ loss = outputs[0]
+ booster.backward(loss, optimizer)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ if not use_pipeline:
+ all_reduce_mean(loss)
+ if print_flag:
+ pbar.set_postfix({'loss': loss.item()})
+ writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
+
+ if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
+ coordinator.print_on_master(f'Saving checkpoint')
+ save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
+ args.save_dir)
+ coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
+        # the subsequent epochs are not resumed, so reset the sampler start index and start step
+ dataloader.sampler.set_start_index(0)
+ start_step = 0
+
+ coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
index b72a3019692e..0eeac4035401 100644
--- a/examples/language/llama2/pretrain.py
+++ b/examples/language/llama2/pretrain.py
@@ -21,7 +21,7 @@
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -65,9 +65,10 @@ def format_numel_str(numel: int) -> str:
return f'{numel}'
-def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
+def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
texts = [sample['text'] for sample in batch]
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
+ data = {k: v.cuda() for k, v in data.items()}
data['labels'] = data['input_ids'].clone()
return data
@@ -104,6 +105,10 @@ def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler:
return running_states['epoch'], running_states['step'], running_states['sample_start_index']
+def _criterion(outputs, inputs):
+ return outputs.loss
+
+
def main():
# ==============================
# Parse Arguments
@@ -112,7 +117,7 @@ def main():
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
parser.add_argument('-p',
'--plugin',
- choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'],
+ choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
default='gemini',
help='Choose which plugin to use')
parser.add_argument('-d',
@@ -142,13 +147,6 @@ def main():
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
- # ==============================
- # Initialize Tensorboard
- # ==============================
- if coordinator.is_master():
- os.makedirs(args.tensorboard_dir, exist_ok=True)
- writer = SummaryWriter(args.tensorboard_dir)
-
# ==============================
# Initialize Booster
# ==============================
@@ -170,11 +168,32 @@ def main():
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip)
+ elif args.plugin == 'hybrid_parallel':
+ # modify the param accordingly, default configuration is for llama2-7b
+ plugin = HybridParallelPlugin(tp_size=4,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_jit_fused=False,
+ zero_stage=0,
+ precision='fp32',
+ initial_scale=1)
else:
raise ValueError(f'Unknown plugin {args.plugin}')
booster = Booster(plugin=plugin)
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+
+ # ==============================
+ # Initialize Tensorboard
+ # ==============================
+ if print_flag:
+ os.makedirs(args.tensorboard_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+
# ==============================
# Initialize Tokenizer, Dataset and Dataloader
# ==============================
@@ -188,12 +207,15 @@ def main():
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
- collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length))
+ collate_fn=partial(tokenize_batch_for_pretrain,
+ tokenizer=tokenizer,
+ max_length=args.max_length))
# ==============================
# Initialize Model, Optimizer and LR Scheduler
# ==============================
config = MODEL_CONFIGS[args.config]
+ # use lazy init when using GeminiPlugin
init_ctx = LazyInitContext(
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
@@ -236,27 +258,42 @@ def main():
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
num_steps_per_epoch = len(dataloader)
+
# if resume training, set the sampler start index to the correct value
dataloader.sampler.set_start_index(sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch)
- with tqdm(enumerate(dataloader),
+ step_nums = num_steps_per_epoch - start_step
+ dataloader_iter = iter(dataloader)
+
+ with tqdm(range(step_nums),
desc=f'Epoch {epoch}',
- disable=not coordinator.is_master(),
+ disable=not print_flag,
total=num_steps_per_epoch,
initial=start_step) as pbar:
- for step, batch in pbar:
- batch = {k: v.cuda() for k, v in batch.items()}
- outputs = model(**batch)
- loss = outputs[0]
- booster.backward(loss, optimizer)
+ for step in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ loss = outputs["loss"]
+ else:
+ batch = next(dataloader_iter)
+ outputs = model(**batch)
+ loss = outputs[0]
+ booster.backward(loss, optimizer)
+
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
- all_reduce_mean(loss)
- pbar.set_postfix({'loss': loss.item()})
- if coordinator.is_master():
+ if not use_pipeline:
+ all_reduce_mean(loss)
+ if print_flag:
+ pbar.set_postfix({'loss': loss.item()})
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt
index 3ddf21ffe534..6b475682dad0 100644
--- a/examples/language/llama2/requirements.txt
+++ b/examples/language/llama2/requirements.txt
@@ -1,4 +1,4 @@
-colossalai>=0.3.0
+colossalai>=0.3.2
datasets
numpy
torch>=1.12.0,<=2.0.0
diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md
index 37e1ff4d9008..af1e794374ed 100644
--- a/examples/language/opt/README.md
+++ b/examples/language/opt/README.md
@@ -23,9 +23,9 @@ The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI)
## Our Modifications
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
-the tokenization).
+the tokenization).
-We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
+We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, HybridParallelPlugin and GeminiPlugin.
## Run Demo
@@ -48,6 +48,3 @@ You can run benchmark for OPT model by running the following script:
bash run_benchmark.sh
```
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing.
-
-
-
diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt
index 4422216e6a1c..45bfbc37195f 100644
--- a/examples/language/opt/requirements.txt
+++ b/examples/language/opt/requirements.txt
@@ -1,4 +1,4 @@
-colossalai >= 0.1.12
+colossalai >= 0.3.2
torch >= 1.8.1
datasets >= 1.8.0
-transformers >= 4.20.0
\ No newline at end of file
+transformers >= 4.30.2
From d151dcab740eaae784333c93d85100c3641bd115 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 15 Sep 2023 21:04:07 +0800
Subject: [PATCH 16/58] [doc] explanation of loading large pretrained models
(#4741)
---
docs/source/en/basics/booster_checkpoint.md | 24 +++++++++++++++++++
.../zh-Hans/basics/booster_checkpoint.md | 24 +++++++++++++++++++
2 files changed, 48 insertions(+)
diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md
index 4ef35dc9a9bb..ea6c11ae2cdc 100644
--- a/docs/source/en/basics/booster_checkpoint.md
+++ b/docs/source/en/basics/booster_checkpoint.md
@@ -19,6 +19,30 @@ Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint
Model must be boosted by `colossalai.booster.Booster` before loading. It will detect the checkpoint format automatically and load it in the corresponding way.
+If you want to load a pretrained model from Huggingface while the model is too large to be directly loaded through `from_pretrained` on a single device, a recommended way is to download the pretrained weights to a local directory, and use `booster.load_model` to load from that directory after boosting the model. Also, the model should be initialized under a lazy initialization context to avoid OOM. Here is an example pseudocode:
+```python
+from colossalai.lazy import LazyInitContext
+from colossalai.utils import get_current_device
+from huggingface_hub import snapshot_download
+...
+
+# Initialize model under lazy init context
+init_ctx = LazyInitContext(default_device=get_current_device())
+with init_ctx:
+ model = LlamaForCausalLM(config)
+
+...
+
+# Wrap the model through Booster.boost
+model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+# Download the huggingface pretrained model to a local directory.
+model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp")
+
+# Load the model weights using booster.load_model.
+booster.load_model(model, model_dir)
+...
+```
+
## Optimizer Checkpoint
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md
index 02557ad47d56..1ff2e330521c 100644
--- a/docs/source/zh-Hans/basics/booster_checkpoint.md
+++ b/docs/source/zh-Hans/basics/booster_checkpoint.md
@@ -19,6 +19,30 @@
Model must be boosted by `colossalai.booster.Booster` before loading. It will detect the checkpoint format automatically and load it in the corresponding way.
+If you want to load a pretrained model from Hugging Face but the model is too large to be loaded directly through `from_pretrained` on a single device, the recommended approach is to download the pretrained weights to a local directory and, after boosting the model, use `booster.load_model` to load them directly from that local path. The model should also be initialized under a lazy initialization context to avoid OOM. Here is example pseudocode:
+```python
+from colossalai.lazy import LazyInitContext
+from colossalai.utils import get_current_device
+from huggingface_hub import snapshot_download
+from transformers import LlamaForCausalLM
+...
+
+# Initialize the model under a lazy-init context so that no real storage
+# is allocated until the checkpoint is loaded.
+init_ctx = LazyInitContext(default_device=get_current_device())
+with init_ctx:
+    model = LlamaForCausalLM(config)
+
+...
+
+# Wrap the model through Booster.boost
+model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+# Download the Hugging Face pretrained weights to a local directory.
+model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp")
+
+# Load the weights from the local directory with booster.load_model.
+booster.load_model(model, model_dir)
+...
+```
+
## Optimizer Checkpoint
From 32e7f99416c846402d6098419777edee3ddbce7b Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Date: Mon, 18 Sep 2023 09:44:27 +0800
Subject: [PATCH 17/58] [kernel] update triton init #4740 (#4740)
---
colossalai/kernel/triton/__init__.py | 30 ++++++++++++++++++----------
1 file changed, 19 insertions(+), 11 deletions(-)
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 5840ad2918be..75812db036a9 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -1,12 +1,20 @@
-from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
-from .copy_kv_cache_dest import copy_kv_cache_to_dest
-from .fused_layernorm import layer_norm
-from .rms_norm import rmsnorm_forward
-from .rotary_embedding_kernel import rotary_embedding_fwd
-from .softmax import softmax
-from .token_attention_kernel import token_attention_fwd
+try:
+ import triton
+ HAS_TRITON = True
-__all__ = [
- "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
- "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
-]
+ from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
+ from .copy_kv_cache_dest import copy_kv_cache_to_dest
+ from .fused_layernorm import layer_norm
+ from .rms_norm import rmsnorm_forward
+ from .rotary_embedding_kernel import rotary_embedding_fwd
+ from .softmax import softmax
+ from .token_attention_kernel import token_attention_fwd
+
+ __all__ = [
+ "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+ "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+ ]
+
+except ImportError:
+ HAS_TRITON = False
+ print("Triton is not installed. Please install Triton to use Triton kernels.")
From b5f9e37c709656b286940f1b5e05abddfa257e3d Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 18 Sep 2023 16:31:06 +0800
Subject: [PATCH 18/58] [legacy] clean up legacy code (#4743)
* [legacy] remove outdated code of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
---
.github/workflows/doc_test_on_pr.yml | 2 +-
.github/workflows/doc_test_on_schedule.yml | 2 +-
.../workflows/example_check_on_dispatch.yml | 2 +-
.github/workflows/example_check_on_pr.yml | 2 +-
.../workflows/example_check_on_schedule.yml | 2 +-
colossalai/__init__.py | 11 +-
colossalai/amp/__init__.py | 54 --
colossalai/amp/naive_amp/__init__.py | 60 --
.../auto_parallel/offload/amp_optimizer.py | 4 +-
colossalai/checkpoint_io/utils.py | 4 -
colossalai/cli/benchmark/__init__.py | 28 -
colossalai/cli/benchmark/benchmark.py | 105 ---
colossalai/cli/benchmark/models.py | 18 -
colossalai/cli/benchmark/utils.py | 159 -----
colossalai/cli/cli.py | 2 -
colossalai/context/__init__.py | 12 +-
colossalai/context/moe_context.py | 7 +-
colossalai/core.py | 6 -
colossalai/fx/passes/shard_1d_pass.py | 16 +-
colossalai/initialize.py | 326 +--------
colossalai/legacy/__init__.py | 9 +
colossalai/legacy/amp/__init__.py | 54 ++
colossalai/{ => legacy}/amp/amp_type.py | 0
.../{ => legacy}/amp/apex_amp/__init__.py | 0
.../{ => legacy}/amp/apex_amp/apex_amp.py | 6 +-
colossalai/legacy/amp/naive_amp/__init__.py | 60 ++
.../amp/naive_amp/_fp16_optimizer.py | 9 +-
.../{ => legacy}/amp/naive_amp/_utils.py | 0
.../{ => legacy}/amp/naive_amp/naive_amp.py | 10 +-
.../{ => legacy}/amp/torch_amp/__init__.py | 0
.../amp/torch_amp/_grad_scaler.py | 4 +-
.../{ => legacy}/amp/torch_amp/torch_amp.py | 6 +-
colossalai/legacy/communication/collective.py | 14 +-
colossalai/legacy/communication/p2p.py | 4 +-
colossalai/legacy/communication/p2p_v2.py | 4 +-
colossalai/legacy/communication/ring.py | 4 +-
colossalai/legacy/communication/utils.py | 4 +-
colossalai/{ => legacy}/constants.py | 0
colossalai/legacy/context/__init__.py | 4 +
.../{ => legacy}/context/parallel_context.py | 66 +-
.../{ => legacy}/context/parallel_mode.py | 0
.../process_group_initializer/__init__.py | 2 +-
.../initializer_1d.py | 2 +-
.../initializer_2d.py | 2 +-
.../initializer_2p5d.py | 2 +-
.../initializer_3d.py | 2 +-
.../initializer_data.py | 0
.../initializer_model.py | 0
.../initializer_pipeline.py | 0
.../initializer_sequence.py | 0
.../initializer_tensor.py | 0
.../process_group_initializer.py | 0
.../{ => legacy}/context/random/__init__.py | 0
.../{ => legacy}/context/random/_helper.py | 12 +-
.../context/random/seed_manager.py | 10 +-
colossalai/legacy/core.py | 6 +
colossalai/legacy/engine/_base_engine.py | 10 +-
.../_gradient_accumulation.py | 6 +-
.../_data_parallel_gradient_handler.py | 4 +-
.../gradient_handler/_moe_gradient_handler.py | 4 +-
.../_pipeline_parallel_gradient_handler.py | 2 +-
.../_sequence_parallel_gradient_handler.py | 4 +-
.../engine/schedule/_pipeline_schedule.py | 10 +-
.../engine/schedule/_pipeline_schedule_v2.py | 4 +-
colossalai/{ => legacy}/global_variables.py | 0
colossalai/legacy/initialize.py | 472 +++++++++++++
colossalai/legacy/nn/__init__.py | 1 -
colossalai/legacy/nn/_ops/__init__.py | 10 +-
colossalai/legacy/nn/_ops/_utils.py | 5 +-
colossalai/legacy/nn/_ops/addmm.py | 90 ---
colossalai/legacy/nn/_ops/batch_norm.py | 33 -
colossalai/legacy/nn/_ops/element_wise.py | 250 -------
colossalai/legacy/nn/_ops/embedding.py | 142 ----
colossalai/legacy/nn/_ops/embedding_bag.py | 127 ----
colossalai/legacy/nn/_ops/layernorm.py | 28 -
colossalai/legacy/nn/_ops/linear.py | 171 -----
colossalai/legacy/nn/_ops/loss.py | 51 --
colossalai/legacy/nn/_ops/view.py | 96 ---
colossalai/legacy/nn/layer/base_layer.py | 4 +-
.../nn/layer/colossalai_layer/dropout.py | 2 +-
.../legacy/nn/layer/parallel_1d/_operation.py | 2 +-
.../legacy/nn/layer/parallel_1d/_utils.py | 4 +-
.../legacy/nn/layer/parallel_1d/layers.py | 10 +-
.../legacy/nn/layer/parallel_2d/_operation.py | 40 +-
.../legacy/nn/layer/parallel_2d/_utils.py | 6 +-
.../legacy/nn/layer/parallel_2d/layers.py | 11 +-
.../nn/layer/parallel_2p5d/_operation.py | 34 +-
.../legacy/nn/layer/parallel_2p5d/_utils.py | 6 +-
.../legacy/nn/layer/parallel_2p5d/layers.py | 10 +-
.../legacy/nn/layer/parallel_3d/_operation.py | 44 +-
.../legacy/nn/layer/parallel_3d/_utils.py | 12 +-
.../legacy/nn/layer/parallel_3d/layers.py | 18 +-
.../nn/layer/parallel_sequence/_operation.py | 4 +-
.../nn/layer/parallel_sequence/layers.py | 6 +-
colossalai/legacy/nn/layer/utils/common.py | 6 +-
colossalai/legacy/nn/layer/vanilla/layers.py | 6 +-
.../nn/layer/wrapper/pipeline_wrapper.py | 4 +-
colossalai/legacy/nn/loss/__init__.py | 2 +-
colossalai/legacy/nn/loss/loss_1d.py | 4 +-
colossalai/legacy/nn/loss/loss_2d.py | 4 +-
colossalai/legacy/nn/loss/loss_2p5d.py | 4 +-
colossalai/legacy/nn/loss/loss_3d.py | 4 +-
colossalai/legacy/nn/metric/accuracy_3d.py | 2 +-
.../legacy/nn/parallel/data_parallel.py | 6 +-
.../parallel_cached_embedding.py | 3 +-
.../parallel_cached_embedding_tablewise.py | 2 +-
..._cached_embedding_tablewise_split_cache.py | 2 +-
.../legacy/nn/parallel/layers/colo_module.py | 4 +-
.../legacy/nn/parallel/layers/embedding.py | 2 +-
.../legacy/nn/parallel/layers/linear.py | 2 +-
.../legacy/nn/parallel/layers/module_utils.py | 3 +-
colossalai/legacy/pipeline/__init__.py | 4 +
.../{ => legacy}/pipeline/layer_spec.py | 6 +-
.../legacy/pipeline/middleware/__init__.py | 3 +
.../pipeline/middleware/adaptor/__init__.py | 2 +-
.../pipeline/middleware/adaptor/fx.py | 34 +-
.../{ => legacy}/pipeline/middleware/topo.py | 86 +--
.../{ => legacy}/pipeline/pipelinable.py | 26 +-
.../pipeline/pipeline_process_group.py | 6 +-
colossalai/legacy/pipeline/rpc/__init__.py | 4 +
.../pipeline/rpc/_pipeline_base.py | 6 +-
.../pipeline/rpc/_pipeline_schedule.py | 8 +-
colossalai/{ => legacy}/pipeline/rpc/utils.py | 2 +-
colossalai/{ => legacy}/pipeline/utils.py | 0
colossalai/legacy/tensor/__init__.py | 17 +
.../{ => legacy}/tensor/compute_spec.py | 0
colossalai/{ => legacy}/tensor/const.py | 0
.../{ => legacy}/tensor/dist_spec_mgr.py | 6 +-
colossalai/{ => legacy}/tensor/distspec.py | 0
colossalai/{ => legacy}/tensor/op_wrapper.py | 5 +-
.../{ => legacy}/tensor/process_group.py | 0
colossalai/{ => legacy}/tensor/tensor_spec.py | 4 +-
colossalai/legacy/trainer/_trainer.py | 3 +-
.../legacy/trainer/hooks/_checkpoint_hook.py | 2 +-
colossalai/legacy/trainer/hooks/_log_hook.py | 11 +-
.../legacy/trainer/hooks/_metric_hook.py | 7 +-
colossalai/legacy/utils/__init__.py | 53 ++
.../utils/activation_checkpoint.py | 16 +-
.../legacy/utils/checkpoint/__init__.py | 3 +
.../utils/checkpoint/module_checkpoint.py | 21 +-
.../{ => legacy}/utils/checkpoint/utils.py | 128 ++--
.../{ => legacy}/utils/checkpointing.py | 8 +-
colossalai/legacy/utils/common.py | 434 ++++++++++++
.../utils/data_sampler/__init__.py | 0
.../utils/data_sampler/base_sampler.py | 0
.../data_sampler/data_parallel_sampler.py | 4 +-
colossalai/{ => legacy}/utils/memory.py | 18 +-
.../{ => legacy}/utils/profiler/__init__.py | 0
.../{ => legacy}/utils/profiler/extention.py | 0
.../utils/profiler/legacy/__init__.py | 12 +-
.../utils/profiler/legacy/comm_profiler.py | 619 +++++++++---------
.../utils/profiler/legacy/pcie_profiler.py | 298 ++++-----
.../utils/profiler/legacy/prof_utils.py | 263 ++++----
.../{ => legacy}/utils/profiler/profiler.py | 4 +-
.../profiler/stateful_tensor_mem_extention.py | 2 +-
.../{zero/legacy => legacy/zero}/__init__.py | 0
.../legacy => legacy/zero}/gemini/__init__.py | 0
.../zero}/gemini/gemini_context.py | 0
.../zero}/gemini/ophooks/__init__.py | 0
.../gemini/ophooks/_shard_grad_ophook.py | 0
.../gemini/ophooks/_shard_param_ophook.py | 0
.../gemini/ophooks/runtime_mem_tracer_hook.py | 2 +-
.../zero}/gemini/ophooks/utils.py | 0
.../zero}/gemini/paramhooks/__init__.py | 0
.../zero}/gemini/paramhooks/_param_hookmgr.py | 0
.../zero}/gemini/stateful_tensor.py | 0
.../zero}/gemini/stateful_tensor_mgr.py | 0
.../zero}/gemini/tensor_placement_policy.py | 2 +-
.../zero}/gemini/tensor_utils.py | 0
.../zero}/init_ctx/__init__.py | 0
.../zero}/init_ctx/init_context.py | 12 +-
.../zero}/shard_utils/__init__.py | 0
.../zero}/shard_utils/base_shard_strategy.py | 2 +-
.../bucket_tensor_shard_strategy.py | 2 +-
.../zero}/shard_utils/commons.py | 0
.../shard_utils/tensor_shard_strategy.py | 8 +-
.../zero}/sharded_model/__init__.py | 0
.../zero}/sharded_model/_utils.py | 2 +-
.../zero}/sharded_model/reduce_scatter.py | 0
.../zero}/sharded_model/sharded_model_v2.py | 24 +-
.../zero}/sharded_model/utils.py | 2 +-
.../zero}/sharded_model/zero_hook.py | 8 +-
.../zero}/sharded_optim/__init__.py | 0
.../zero}/sharded_optim/sharded_optim_v2.py | 18 +-
.../zero}/sharded_param/__init__.py | 0
.../zero}/sharded_param/sharded_param.py | 4 +-
.../zero}/sharded_param/sharded_tensor.py | 2 +-
colossalai/logging/logger.py | 8 -
colossalai/nn/layer/__init__.py | 2 +-
colossalai/nn/layer/moe/experts.py | 4 +-
colossalai/nn/layer/moe/layers.py | 2 +-
colossalai/nn/loss/__init__.py | 2 +-
colossalai/nn/optimizer/__init__.py | 7 +-
.../nn/optimizer/colossalai_optimizer.py | 44 --
colossalai/pipeline/__init__.py | 13 +-
colossalai/pipeline/middleware/__init__.py | 3 -
colossalai/pipeline/rpc/__init__.py | 4 -
colossalai/pipeline/schedule/__init__.py | 2 +
colossalai/tensor/__init__.py | 11 +-
colossalai/utils/__init__.py | 59 +-
colossalai/utils/checkpoint/__init__.py | 3 -
colossalai/utils/common.py | 438 +------------
colossalai/utils/cuda.py | 11 +-
colossalai/utils/moe.py | 106 +--
colossalai/zero/gemini/colo_init_context.py | 3 +-
.../zero/gemini/memory_tracer/__init__.py | 5 +-
.../memory_tracer/chunk_memstats_collector.py | 2 +-
.../gemini/memory_tracer/memory_monitor.py | 3 +-
.../memory_tracer/memstats_collector.py | 2 +-
.../memory_tracer/runtime_mem_tracer.py | 6 +-
colossalai/zero/gemini/placement_policy.py | 2 +-
colossalai/zero/low_level/_utils.py | 3 -
docs/README.md | 2 +-
.../advanced_tutorials/add_your_parallel.md | 2 +-
.../train_gpt_using_hybrid_parallelism.md | 2 +-
docs/source/en/basics/command_line_tool.md | 22 +-
.../advanced_tutorials/add_your_parallel.md | 2 +-
.../train_gpt_using_hybrid_parallelism.md | 2 +-
.../zh-Hans/basics/command_line_tool.md | 20 +-
.../roberta/pretraining/pretrain_utils.py | 2 +-
.../roberta/pretraining/run_pretraining.py | 2 +-
.../roberta/pretraining/utils/exp_util.py | 2 +-
examples/images/dreambooth/test_ci.sh | 42 +-
.../dreambooth/train_dreambooth_colossalai.py | 9 +-
.../train_dreambooth_colossalai_lora.py | 4 +-
.../auto_parallel/auto_parallel_with_gpt.py | 2 +-
.../pipeline_parallel/train_gpt_pp.py | 8 +-
examples/language/gpt/gemini/run_gemini.sh | 7 +-
.../language/gpt/gemini/train_gpt_demo.py | 4 +-
examples/language/gpt/test_ci.sh | 2 +-
examples/language/gpt/titans/model/embed.py | 4 +-
examples/language/gpt/titans/model/gpt1d.py | 4 +-
.../gpt/titans/model/pipeline_gpt1d.py | 6 +-
examples/language/gpt/titans/train_gpt.py | 6 +-
.../auto_parallel_with_resnet.py | 2 +-
examples/tutorial/auto_parallel/test_ci.sh | 8 +-
examples/tutorial/hybrid_parallel/config.py | 2 +-
examples/tutorial/hybrid_parallel/train.py | 6 +-
.../tutorial/large_batch_optimizer/config.py | 2 +-
.../tutorial/large_batch_optimizer/test_ci.sh | 7 +-
.../tutorial/large_batch_optimizer/train.py | 2 +-
examples/tutorial/opt/opt/colossalai_zero.py | 2 +-
examples/tutorial/opt/opt/context.py | 4 +-
examples/tutorial/opt/opt/run_clm.py | 9 +-
examples/tutorial/opt/opt/test_ci.sh | 32 +-
examples/tutorial/sequence_parallel/config.py | 2 +-
.../sequence_parallel/data/__init__.py | 34 +-
.../sequence_parallel/data/bert_helper.py | 23 +-
.../data/datasets/bert_dataset.py | 4 +-
.../data/datasets/data_samplers.py | 8 +-
.../data/tokenizer/tokenizer.py | 30 +-
.../sequence_parallel/loss_func/bert_loss.py | 28 +-
.../loss_func/cross_entropy.py | 12 +-
.../tutorial/sequence_parallel/model/bert.py | 8 +-
.../sequence_parallel/model/layers/head.py | 25 +-
.../model/layers/preprocess.py | 9 +-
.../tutorial/sequence_parallel/test_ci.sh | 5 +-
examples/tutorial/sequence_parallel/train.py | 9 +-
tests/components_to_test/resnet.py | 13 +-
.../test_C_solver_consistency.py | 2 +-
.../test_ckpt_torchvision.py | 2 +-
.../test_compatibility_with_gemini.py | 5 +-
.../test_autochunk_alphafold_utils.py | 2 +-
.../test_autochunk_diffuser_utils.py | 2 +-
.../test_autochunk_vit_utils.py | 2 +-
tests/test_cluster/test_process_group_mesh.py | 6 +-
.../test_context/configs/parallel_2d_init.py | 10 -
.../configs/parallel_2p5d_init.py | 11 -
.../test_context/configs/parallel_3d_init.py | 10 -
tests/test_device/test_init_logical_pg.py | 5 +-
.../test_activation_checkpoint_codegen.py | 2 +-
...st_nested_activation_checkpoint_codegen.py | 2 +-
.../test_codegen/test_offload_codegen.py | 2 +-
tests/test_fx/test_parallel_1d.py | 2 +-
.../test_pipeline/test_topo/topo_utils.py | 33 +-
.../test_amp/test_naive_fp16.py | 4 +-
.../test_amp/test_torch_fp16.py | 4 +-
.../test_comm/test_boardcast_send_recv_v2.py | 6 +-
tests/test_legacy/test_comm/test_comm.py | 6 +-
.../test_comm/test_object_list_p2p.py | 6 +-
.../test_comm/test_object_list_p2p_v2.py | 6 +-
.../test_context/configs/parallel_2d_init.py | 4 +
.../configs/parallel_2p5d_init.py | 4 +
.../test_context/configs/parallel_3d_init.py | 4 +
.../test_context/test_hybrid_parallel.py | 10 +-
.../test_data/test_cifar10_dataset.py | 0
.../test_data/test_data_parallel_sampler.py | 9 +-
.../test_deterministic_dataloader.py | 74 +++
tests/test_legacy/test_engine/test_engine.py | 20 +-
.../test_engine/test_gradient_accumluation.py | 19 +-
.../test_1d/checks_1d/check_layer_1d.py | 9 +-
.../test_layers/test_1d/test_1d.py | 4 +-
.../test_2d/checks_2d/check_layer_2d.py | 7 +-
.../test_2d/checks_2d/check_operation_2d.py | 7 +-
.../test_layers/test_2d/test_2d.py | 4 +-
.../test_2p5d/checks_2p5d/check_layer_2p5d.py | 7 +-
.../checks_2p5d/check_operation_2p5d.py | 7 +-
.../test_layers/test_2p5d/test_2p5d.py | 4 +-
.../test_3d/checks_3d/check_layer_3d.py | 7 +-
.../test_layers/test_3d/test_3d.py | 4 +-
.../test_layers/test_cache_embedding.py | 5 +-
.../checks_seq/check_layer_seq.py | 4 +-
.../test_sequence/test_sequence.py | 6 +-
.../test_pipeline/rpc_test_utils.py | 4 +-
.../test_pipeline/test_cuda_rpc_chimera.py | 6 +-
.../test_pipeline/test_cuda_rpc_optimizer.py | 9 +-
.../test_pipeline/test_cuda_rpc_pipeline.py | 4 +-
.../test_cuda_rpc_value_correctness.py | 7 +-
.../test_pipeline/test_middleware_1f1b.py | 8 +-
.../test_pipeline/test_pipelinable.py | 2 +-
.../test_pipeline_process_group.py | 4 +-
.../test_tensor/common_utils/__init__.py | 2 +-
.../test_tensor/common_utils/_utils.py | 6 +-
.../test_tensor/core/test_dist_spec_mgr.py | 4 +-
.../test_tensor/test_parameter.py | 4 +-
.../test_trainer/test_pipeline/test_p2p.py | 6 +-
.../test_pipeline/test_pipeline_schedule.py | 10 +-
.../test_trainer_with_non_pipe_schedule.py | 17 +-
.../test_trainer_with_pipe_schedule.py | 22 +-
.../test_activation_checkpointing.py | 6 +-
.../test_checkpoint/test_checkpoint_1d.py | 12 +-
.../test_checkpoint/test_checkpoint_2d.py | 12 +-
.../test_checkpoint/test_checkpoint_2p5d.py | 12 +-
.../test_checkpoint/test_checkpoint_3d.py | 12 +-
.../test_utils/test_memory.py | 4 +-
.../test_utils/test_norm_gradient_clipping.py | 6 +-
.../test_zero}/test_commons.py | 6 +-
tests/test_moe/test_kernel.py | 4 +-
tests/test_moe/test_moe_zero_optim.py | 2 +-
tests/test_tensor/test_comm_spec_apply.py | 5 +-
.../test_dtensor/test_comm_spec.py | 6 +-
tests/test_tensor/test_mix_gather.py | 4 +-
.../test_zero_gradient_clippling.py | 111 ----
.../test_zero/test_gemini/test_chunk_mgrv2.py | 2 -
tests/test_zero/test_gemini/test_fwd_bwd.py | 4 +-
.../test_gemini/test_gemini_use_rmt.py | 2 +-
tests/test_zero/test_gemini/test_grad_clip.py | 4 +-
tests/test_zero/test_gemini/test_inference.py | 4 +-
tests/test_zero/test_gemini/test_optim.py | 4 +-
.../test_gemini/test_zeroddp_state_dict.py | 2 +-
.../test_gemini/test_zerooptim_state_dict.py | 2 +-
.../test_zero/test_low_level/test_zero_tp.py | 96 ---
342 files changed, 2917 insertions(+), 4180 deletions(-)
delete mode 100644 colossalai/cli/benchmark/__init__.py
delete mode 100644 colossalai/cli/benchmark/benchmark.py
delete mode 100644 colossalai/cli/benchmark/models.py
delete mode 100644 colossalai/cli/benchmark/utils.py
delete mode 100644 colossalai/core.py
create mode 100644 colossalai/legacy/amp/__init__.py
rename colossalai/{ => legacy}/amp/amp_type.py (100%)
rename colossalai/{ => legacy}/amp/apex_amp/__init__.py (100%)
rename colossalai/{ => legacy}/amp/apex_amp/apex_amp.py (86%)
create mode 100644 colossalai/legacy/amp/naive_amp/__init__.py
rename colossalai/{ => legacy}/amp/naive_amp/_fp16_optimizer.py (97%)
rename colossalai/{ => legacy}/amp/naive_amp/_utils.py (100%)
rename colossalai/{ => legacy}/amp/naive_amp/naive_amp.py (94%)
rename colossalai/{ => legacy}/amp/torch_amp/__init__.py (100%)
rename colossalai/{ => legacy}/amp/torch_amp/_grad_scaler.py (99%)
rename colossalai/{ => legacy}/amp/torch_amp/torch_amp.py (95%)
rename colossalai/{ => legacy}/constants.py (100%)
create mode 100644 colossalai/legacy/context/__init__.py
rename colossalai/{ => legacy}/context/parallel_context.py (88%)
rename colossalai/{ => legacy}/context/parallel_mode.py (100%)
rename colossalai/{ => legacy}/context/process_group_initializer/__init__.py (100%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_1d.py (96%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_2d.py (98%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_2p5d.py (99%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_3d.py (99%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_data.py (100%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_model.py (100%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_pipeline.py (100%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_sequence.py (100%)
rename colossalai/{ => legacy}/context/process_group_initializer/initializer_tensor.py (100%)
rename colossalai/{ => legacy}/context/process_group_initializer/process_group_initializer.py (100%)
rename colossalai/{ => legacy}/context/random/__init__.py (100%)
rename colossalai/{ => legacy}/context/random/_helper.py (90%)
rename colossalai/{ => legacy}/context/random/seed_manager.py (86%)
create mode 100644 colossalai/legacy/core.py
rename colossalai/{ => legacy}/global_variables.py (100%)
create mode 100644 colossalai/legacy/initialize.py
delete mode 100644 colossalai/legacy/nn/_ops/addmm.py
delete mode 100644 colossalai/legacy/nn/_ops/batch_norm.py
delete mode 100644 colossalai/legacy/nn/_ops/element_wise.py
delete mode 100644 colossalai/legacy/nn/_ops/embedding.py
delete mode 100644 colossalai/legacy/nn/_ops/embedding_bag.py
delete mode 100644 colossalai/legacy/nn/_ops/layernorm.py
delete mode 100644 colossalai/legacy/nn/_ops/linear.py
delete mode 100644 colossalai/legacy/nn/_ops/loss.py
delete mode 100644 colossalai/legacy/nn/_ops/view.py
create mode 100644 colossalai/legacy/pipeline/__init__.py
rename colossalai/{ => legacy}/pipeline/layer_spec.py (97%)
create mode 100644 colossalai/legacy/pipeline/middleware/__init__.py
rename colossalai/{ => legacy}/pipeline/middleware/adaptor/__init__.py (62%)
rename colossalai/{ => legacy}/pipeline/middleware/adaptor/fx.py (92%)
rename colossalai/{ => legacy}/pipeline/middleware/topo.py (95%)
rename colossalai/{ => legacy}/pipeline/pipelinable.py (93%)
rename colossalai/{ => legacy}/pipeline/pipeline_process_group.py (98%)
create mode 100644 colossalai/legacy/pipeline/rpc/__init__.py
rename colossalai/{ => legacy}/pipeline/rpc/_pipeline_base.py (99%)
rename colossalai/{ => legacy}/pipeline/rpc/_pipeline_schedule.py (97%)
rename colossalai/{ => legacy}/pipeline/rpc/utils.py (98%)
rename colossalai/{ => legacy}/pipeline/utils.py (100%)
create mode 100644 colossalai/legacy/tensor/__init__.py
rename colossalai/{ => legacy}/tensor/compute_spec.py (100%)
rename colossalai/{ => legacy}/tensor/const.py (100%)
rename colossalai/{ => legacy}/tensor/dist_spec_mgr.py (97%)
rename colossalai/{ => legacy}/tensor/distspec.py (100%)
rename colossalai/{ => legacy}/tensor/op_wrapper.py (97%)
rename colossalai/{ => legacy}/tensor/process_group.py (100%)
rename colossalai/{ => legacy}/tensor/tensor_spec.py (79%)
create mode 100644 colossalai/legacy/utils/__init__.py
rename colossalai/{ => legacy}/utils/activation_checkpoint.py (95%)
create mode 100644 colossalai/legacy/utils/checkpoint/__init__.py
rename colossalai/{ => legacy}/utils/checkpoint/module_checkpoint.py (90%)
rename colossalai/{ => legacy}/utils/checkpoint/utils.py (91%)
rename colossalai/{ => legacy}/utils/checkpointing.py (98%)
create mode 100644 colossalai/legacy/utils/common.py
rename colossalai/{ => legacy}/utils/data_sampler/__init__.py (100%)
rename colossalai/{ => legacy}/utils/data_sampler/base_sampler.py (100%)
rename colossalai/{ => legacy}/utils/data_sampler/data_parallel_sampler.py (98%)
rename colossalai/{ => legacy}/utils/memory.py (95%)
rename colossalai/{ => legacy}/utils/profiler/__init__.py (100%)
rename colossalai/{ => legacy}/utils/profiler/extention.py (100%)
rename colossalai/{ => legacy}/utils/profiler/legacy/__init__.py (77%)
rename colossalai/{ => legacy}/utils/profiler/legacy/comm_profiler.py (96%)
rename colossalai/{ => legacy}/utils/profiler/legacy/pcie_profiler.py (95%)
rename colossalai/{ => legacy}/utils/profiler/legacy/prof_utils.py (94%)
rename colossalai/{ => legacy}/utils/profiler/profiler.py (97%)
rename colossalai/{ => legacy}/utils/profiler/stateful_tensor_mem_extention.py (98%)
rename colossalai/{zero/legacy => legacy/zero}/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/gemini_context.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/_shard_grad_ophook.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/_shard_param_ophook.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/runtime_mem_tracer_hook.py (98%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/ophooks/utils.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/paramhooks/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/paramhooks/_param_hookmgr.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/stateful_tensor.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/stateful_tensor_mgr.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/tensor_placement_policy.py (98%)
rename colossalai/{zero/legacy => legacy/zero}/gemini/tensor_utils.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/init_ctx/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/init_ctx/init_context.py (96%)
rename colossalai/{zero/legacy => legacy/zero}/shard_utils/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/shard_utils/base_shard_strategy.py (90%)
rename colossalai/{zero/legacy => legacy/zero}/shard_utils/bucket_tensor_shard_strategy.py (97%)
rename colossalai/{zero/legacy => legacy/zero}/shard_utils/commons.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/shard_utils/tensor_shard_strategy.py (90%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_model/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_model/_utils.py (97%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_model/reduce_scatter.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_model/sharded_model_v2.py (97%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_model/utils.py (92%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_model/zero_hook.py (94%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_optim/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_optim/sharded_optim_v2.py (97%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_param/__init__.py (100%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_param/sharded_param.py (96%)
rename colossalai/{zero/legacy => legacy/zero}/sharded_param/sharded_tensor.py (94%)
delete mode 100644 colossalai/nn/optimizer/colossalai_optimizer.py
delete mode 100644 colossalai/pipeline/middleware/__init__.py
delete mode 100644 colossalai/pipeline/rpc/__init__.py
delete mode 100644 colossalai/utils/checkpoint/__init__.py
delete mode 100644 tests/test_context/configs/parallel_2d_init.py
delete mode 100644 tests/test_context/configs/parallel_2p5d_init.py
delete mode 100644 tests/test_context/configs/parallel_3d_init.py
rename tests/{ => test_legacy}/test_amp/test_naive_fp16.py (94%)
rename tests/{ => test_legacy}/test_amp/test_torch_fp16.py (95%)
create mode 100644 tests/test_legacy/test_context/configs/parallel_2d_init.py
create mode 100644 tests/test_legacy/test_context/configs/parallel_2p5d_init.py
create mode 100644 tests/test_legacy/test_context/configs/parallel_3d_init.py
rename tests/{ => test_legacy}/test_context/test_hybrid_parallel.py (95%)
rename tests/{ => test_legacy}/test_data/test_cifar10_dataset.py (100%)
rename tests/{ => test_legacy}/test_data/test_data_parallel_sampler.py (87%)
create mode 100644 tests/test_legacy/test_data/test_deterministic_dataloader.py
rename tests/{ => test_legacy}/test_pipeline/rpc_test_utils.py (97%)
rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_chimera.py (94%)
rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_optimizer.py (89%)
rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_pipeline.py (87%)
rename tests/{ => test_legacy}/test_pipeline/test_cuda_rpc_value_correctness.py (91%)
rename tests/{ => test_legacy}/test_pipeline/test_middleware_1f1b.py (94%)
rename tests/{ => test_legacy}/test_pipeline/test_pipelinable.py (96%)
rename tests/{ => test_legacy}/test_pipeline/test_pipeline_process_group.py (91%)
rename tests/{ => test_legacy}/test_tensor/common_utils/__init__.py (95%)
rename tests/{ => test_legacy}/test_tensor/common_utils/_utils.py (93%)
rename tests/{ => test_legacy}/test_tensor/core/test_dist_spec_mgr.py (91%)
rename tests/{ => test_legacy}/test_tensor/test_parameter.py (82%)
rename tests/{ => test_legacy}/test_utils/test_activation_checkpointing.py (94%)
rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_1d.py (83%)
rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_2d.py (83%)
rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_2p5d.py (84%)
rename tests/{ => test_legacy}/test_utils/test_checkpoint/test_checkpoint_3d.py (83%)
rename tests/{ => test_legacy}/test_utils/test_memory.py (76%)
rename tests/{ => test_legacy}/test_utils/test_norm_gradient_clipping.py (91%)
rename tests/{test_utils => test_legacy/test_zero}/test_commons.py (82%)
delete mode 100644 tests/test_utils/test_zero_gradient_clippling.py
delete mode 100644 tests/test_zero/test_low_level/test_zero_tp.py
diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml
index a3df2c50e6d3..f1e7a2d0cab0 100644
--- a/.github/workflows/doc_test_on_pr.yml
+++ b/.github/workflows/doc_test_on_pr.yml
@@ -89,7 +89,7 @@ jobs:
- name: Install ColossalAI
run: |
source activate pytorch
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Test the Doc
run: |
diff --git a/.github/workflows/doc_test_on_schedule.yml b/.github/workflows/doc_test_on_schedule.yml
index 6b4f5d1f908c..027fbfd0aaeb 100644
--- a/.github/workflows/doc_test_on_schedule.yml
+++ b/.github/workflows/doc_test_on_schedule.yml
@@ -32,7 +32,7 @@ jobs:
- name: Install ColossalAI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Install Doc Test Requirements
run: |
diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml
index 620d4771af55..9d3bd9a48235 100644
--- a/.github/workflows/example_check_on_dispatch.yml
+++ b/.github/workflows/example_check_on_dispatch.yml
@@ -53,7 +53,7 @@ jobs:
uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Test the example
run: |
dir=${{ matrix.directory }}
diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index ec23b9d1c59f..5934704f4102 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -88,7 +88,7 @@ jobs:
- name: Install Colossal-AI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Test the example
run: |
diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml
index bd52ca4321a2..5ed128c3ebc5 100644
--- a/.github/workflows/example_check_on_schedule.yml
+++ b/.github/workflows/example_check_on_schedule.yml
@@ -42,7 +42,7 @@ jobs:
- name: Install Colossal-AI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Traverse all files
run: |
diff --git a/colossalai/__init__.py b/colossalai/__init__.py
index f859161f7810..fa6f72a605c0 100644
--- a/colossalai/__init__.py
+++ b/colossalai/__init__.py
@@ -1,11 +1,4 @@
-from .initialize import (
- get_default_parser,
- initialize,
- launch,
- launch_from_openmpi,
- launch_from_slurm,
- launch_from_torch,
-)
+from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
try:
# .version will be created by setup.py
@@ -15,3 +8,5 @@
# and directly set PYTHONPATH to use Colossal-AI which is a bad practice
__version__ = '0.0.0'
print('please install Colossal-AI from https://www.colossalai.org/download or from source')
+
+__all__ = ['launch', 'launch_from_openmpi', 'launch_from_slurm', 'launch_from_torch', '__version__']
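After this cleanup only the launch helpers remain in the top-level namespace. A minimal sketch of the surviving entry point (meant to be run with `torchrun --nproc_per_node=2 script.py`; the empty `config` dict reflects its deprecation later in this patch):

```python
import torch.distributed as dist

import colossalai

# torchrun supplies RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.
colossalai.launch_from_torch(config={})
print(f"initialized rank {dist.get_rank()} of {dist.get_world_size()}")
```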
diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py
index 963215476b6b..e69de29bb2d1 100644
--- a/colossalai/amp/__init__.py
+++ b/colossalai/amp/__init__.py
@@ -1,54 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-
-from colossalai.context import Config
-
-from .amp_type import AMP_TYPE
-from .apex_amp import convert_to_apex_amp
-from .naive_amp import convert_to_naive_amp
-from .torch_amp import convert_to_torch_amp
-
-__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
-
-
-def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
- """A helper function to wrap training components with Torch AMP modules.
-
- Args:
- param model (:class:`torch.nn.Module`): your model object.
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
- criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
- mode (:class:`colossalai.amp.AMP_TYPE`): amp mode.
- amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.
-
- Returns:
- A tuple (model, optimizer, criterion).
-
- Note:
- ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
- for more details about ``amp_config``.
- For ``apex_amp``, please check
- `apex_amp config `_.
- For ``naive_amp``, please check
- `naive_amp config `_.
- For ``torch_amp``, please check
- `torch_amp config `_.
- """
- assert isinstance(mode, AMP_TYPE), \
- f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
-
- if amp_config is None:
- amp_config = Config()
-
- if mode == AMP_TYPE.TORCH:
- model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
- elif mode == AMP_TYPE.APEX:
- model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
- elif mode == AMP_TYPE.NAIVE:
- model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
-
- return model, optimizer, criterion
diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py
index 5b2f71d3ced7..e69de29bb2d1 100644
--- a/colossalai/amp/naive_amp/__init__.py
+++ b/colossalai/amp/naive_amp/__init__.py
@@ -1,60 +0,0 @@
-import inspect
-
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from colossalai.utils import is_no_pp_or_last_stage
-
-from ._fp16_optimizer import FP16Optimizer
-from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
-
-
-def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
- """A helper function to wrap training components with naive AMP modules. In this mode,
- we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
- which is equivalent to Apex O3.
-
- Args:
- model (:class:`torch.nn.Module`): your model object
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object
- amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
-
- Returns:
- Tuple: A tuple (model, optimizer)
-
- The ``amp_config`` should contain parameters below::
-
- verbose (bool, optional): if set to `True`, will print debug info (Default: False).
- clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
- Note that clipping is ignored if clip_grad == 0.
- dynamic_grad_scale (bool): whether to use dynamic grad scaler.
- """
- if isinstance(model, nn.ModuleList):
- # interleaved pipeline
- module_list = []
- for chunk, m in enumerate(model):
- output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
- module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
- model = nn.ModuleList(module_list)
- else:
- output_to_fp32 = is_no_pp_or_last_stage()
- model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
-
- use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
- if use_dynamic_grad_scaler:
- scaler_class = DynamicGradScaler
- else:
- scaler_class = ConstantGradScaler
-
- sig = inspect.signature(scaler_class.__init__)
- kwargs = dict()
- for param in sig.parameters.values():
- if param.name in amp_config:
- kwargs[param.name] = amp_config.pop(param.name)
- grad_scaler = scaler_class(**kwargs)
- optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
- return model, optimizer
-
-
-__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index 19d85b80dd3d..353133bd6f2d 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -5,8 +5,8 @@
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
@@ -19,7 +19,7 @@ class OptimState(Enum):
UNSCALED = 1
-class AMPOptimizer(ColossalaiOptimizer):
+class AMPOptimizer(OptimizerWrapper):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 3441eca38ce7..664ac63e45ac 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -13,7 +13,6 @@
from torch.optim import Optimizer
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@@ -130,10 +129,7 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
This method should be used before saving/loading it to/from sharded checkpoints.
'''
- # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future
unwrapped_optim = optimizer.optim
- if isinstance(unwrapped_optim, ColossalaiOptimizer):
- unwrapped_optim = unwrapped_optim.optim
return unwrapped_optim
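With the `ColossalaiOptimizer` branch gone, unwrapping is a single attribute access. A small sketch of the resulting contract (constructing `OptimizerWrapper` by hand here purely for illustration):

```python
import torch

from colossalai.checkpoint_io.utils import unwrap_optimizer
from colossalai.interface import OptimizerWrapper

param = torch.nn.Parameter(torch.zeros(4))
inner = torch.optim.SGD([param], lr=0.1)
wrapped = OptimizerWrapper(inner)

# unwrap_optimizer now simply returns the wrapped .optim attribute.
assert unwrap_optimizer(wrapped) is inner
```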
diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py
deleted file mode 100644
index 618ff8c61dd4..000000000000
--- a/colossalai/cli/benchmark/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import click
-
-from colossalai.context import Config
-
-from .benchmark import run_benchmark
-from .utils import *
-
-__all__ = ['benchmark']
-
-
-@click.command()
-@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
-@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
-@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
-@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
-@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
-@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
-@click.option("-l", "--layers", type=int, default=2)
-@click.option("-m",
- "--model",
- type=click.Choice(['mlp'], case_sensitive=False),
- default='mlp',
- help="Select the model to benchmark, currently only supports MLP")
-def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
- layers: int, model: str):
- args_dict = locals()
- args = Config(args_dict)
- run_benchmark(args)
diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py
deleted file mode 100644
index 97a9f45722dd..000000000000
--- a/colossalai/cli/benchmark/benchmark.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from functools import partial
-from typing import Dict, List
-
-import click
-import torch.multiprocessing as mp
-
-import colossalai
-from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model
-from colossalai.context import Config
-from colossalai.context.random import reset_seeds
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.testing import free_port
-from colossalai.utils import MultiTimer
-
-from .models import MLP
-
-
-def run_benchmark(args: Config) -> None:
- """
- Run benchmarking with torch.multiprocessing.
- """
-
- # sanity checks
- if args.gpus is None:
- click.echo("Error: --num_gpus is not given")
- exit()
- if args.gpus <= 1:
- click.echo("Warning: tensor parallel will be activated with at least 2 devices.")
-
- click.echo("=== Benchmarking Parameters ===")
- for k, v in args.items():
- click.echo(f'{k}: {v}')
- click.echo('')
-
- config_list = find_all_configs(args.gpus)
-
- avail_ports = [free_port() for _ in range(len(config_list))]
- run_func = partial(run_dist_profiling,
- world_size=args.gpus,
- port_list=avail_ports,
- config_list=config_list,
- hyperparams=args)
- mp.spawn(run_func, nprocs=args.gpus)
-
-
-def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
- hyperparams: Config) -> None:
- """
- A function executed for profiling, this function should be spawn by torch.multiprocessing.
-
- Args:
- rank (int): rank of the process
- world_size (int): the number of processes
- port_list (List[int]): a list of free ports for initializing distributed networks
- config_list (List[Dict]): a list of configuration
- hyperparams (Config): the hyperparameters given by the user
-
- """
-
- # disable logging for clean output
- disable_existing_loggers()
- logger = get_dist_logger()
- logger.set_level('WARNING')
-
- for config, port in zip(config_list, port_list):
- colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- timer = MultiTimer()
-
- # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size.
- if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0:
- click.echo(
- "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size."
- )
- continue
-
- if hyperparams.model == 'mlp':
- model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
- else:
- if gpc.get_global_rank() == 0:
- click.echo("Error: Invalid argument for --model")
- exit()
-
- data_func = partial(get_batch_data,
- dim=hyperparams.dimension,
- batch_size=hyperparams.batch_size,
- seq_length=hyperparams.seq_len,
- mode=config.parallel.tensor.mode)
-
- fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
- warmup_steps=hyperparams.warmup_steps,
- profile_steps=hyperparams.profile_steps,
- data_func=data_func,
- timer=timer)
-
- gpc.destroy()
- reset_seeds()
-
- if gpc.get_global_rank() == 0:
- config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
- click.echo(f"=== {config_str} ===")
- click.echo(f"Average forward time: {fwd_time}")
- click.echo(f"Average backward time: {bwd_time}")
- click.echo(f"Max allocated GPU memory: {max_allocated}")
- click.echo(f"Max cached GPU memory: {max_cached}\n")
diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py
deleted file mode 100644
index 385b485b6016..000000000000
--- a/colossalai/cli/benchmark/models.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import torch
-
-import colossalai.legacy.nn as col_nn
-
-
-class MLP(torch.nn.Module):
-
- def __init__(self, dim: int, layers: int):
- super().__init__()
- self.layers = torch.nn.ModuleList()
-
- for _ in range(layers):
- self.layers.append(col_nn.Linear(dim, dim))
-
- def forward(self, x):
- for layer in self.layers:
- x = layer(x)
- return x
diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py
deleted file mode 100644
index ee7d92d6ea6a..000000000000
--- a/colossalai/cli/benchmark/utils.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import math
-import time
-from typing import Callable, Dict, List, Tuple
-
-import torch
-
-from colossalai.context import Config, ParallelMode
-from colossalai.utils import MultiTimer
-
-
-def get_time_stamp() -> int:
- """
- Return the time stamp for profiling.
-
- Returns:
- time_stamp (int): the time given by time.time()
- """
-
- torch.cuda.synchronize()
- time_stamp = time.time()
- return time_stamp
-
-
-def get_memory_states() -> Tuple[float]:
- """
- Return the memory statistics.
-
- Returns:
- max_allocated (float): the allocated CUDA memory
- max_cached (float): the cached CUDA memory
- """
-
- max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
- max_cached = torch.cuda.max_memory_reserved() / (1024**3)
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.empty_cache()
- return max_allocated, max_cached
-
-
-def find_all_configs(device_cnt: int) -> List[Dict]:
- """
- Find all possible configurations for tensor parallelism
-
- Args:
- device_cnt (int): the number of devices
-
- Returns:
- config_list (List[Dict]): a list of configurations
- """
-
- def _is_square(num):
- # 2D parallel should be implemented with at least 2 devices.
- if num <= 1:
- return False
- return math.floor(math.sqrt(num))**2 == num
-
- def _is_cube(num):
- # 3D parallel should be implemented with at least 2 devices.
- if num <= 1:
- return False
- return math.floor(num**(1. / 3.))**3 == num
-
- config_list = []
-
- # add non-parallel config
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
- config_list.append(config)
-
- # add 1D config
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
- config_list.append(config)
-
- # add 2D config only if device_cnt is a square
- if _is_square(device_cnt):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
- config_list.append(config)
-
- # check for 2.5D
- # iterate over depth
- for depth in range(1, device_cnt):
- if device_cnt % depth == 0 and _is_square(device_cnt // depth):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
- config_list.append(config)
-
- # check for 3D if device_cnt is a cube
- if _is_cube(device_cnt):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
- config_list.append(config)
-
- config_list = [Config(cfg) for cfg in config_list]
- return config_list
-
-
-def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
- timer: MultiTimer) -> Tuple[float]:
- """
- Profile the forward and backward of a model
-
- Args:
- model (torch.nn.Module): a PyTorch model
- warmup_steps (int): the number of steps for warmup
- profile_steps (int): the number of steps for profiling
- data_func (Callable): a function to generate random data
- timer (colossalai.utils.Multitimer): a timer instance for time recording
-
- Returns:
- fwd_time (float): the average forward time taken by forward pass in second
- bwd_time (float): the average backward time taken by forward pass in second
- max_allocated (float): the maximum GPU memory allocated in GB
- max_cached (float): the maximum GPU memory cached in GB
- """
-
- def _run_step(data):
- timer.start('forward')
- out = model(data)
- timer.stop('forward', keep_in_history=True)
- timer.start('backward')
- out.mean().backward()
- timer.stop('backward', keep_in_history=True)
-
- data_list = [data_func() for _ in range(warmup_steps)]
- for data in data_list:
- _run_step(data)
- timer.reset('forward')
- timer.reset('backward')
-
- for _ in range(profile_steps):
- data = data_func()
- _run_step(data)
-
- max_allocated, max_cached = get_memory_states()
- fwd_time = timer.get_timer('forward').get_history_mean()
- bwd_time = timer.get_timer('backward').get_history_mean()
- return fwd_time, bwd_time, max_allocated, max_cached
-
-
-def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
- """
- Return a random data of shape (batch_size, seq_length, dim) for profiling.
-
- Args:
- dim (int): hidden size
- batch_size (int): the number of data samples
- seq_length (int): the number of tokens
- mode (ParallelMode): Colossal-AI ParallelMode enum
-
- Returns:
- data (torch.Tensor): random data
- """
-
- if mode in ['2d', '2.5d']:
- batch_size = batch_size // 2
- dim = dim // 2
- elif mode == '3d':
- batch_size = batch_size // 4
- dim = dim // 2
-
- data = torch.rand(batch_size, seq_length, dim).cuda()
- return data
diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py
index a94e1150e49f..0dea7c504957 100644
--- a/colossalai/cli/cli.py
+++ b/colossalai/cli/cli.py
@@ -1,6 +1,5 @@
import click
-from .benchmark import benchmark
from .check import check
from .launcher import run
@@ -19,7 +18,6 @@ def cli():
cli.add_command(run)
cli.add_command(check)
-cli.add_command(benchmark)
if __name__ == '__main__':
cli()
diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py
index 50178b5fa850..eb6d5d05a008 100644
--- a/colossalai/context/__init__.py
+++ b/colossalai/context/__init__.py
@@ -1,6 +1,8 @@
from .config import Config, ConfigException
-from .parallel_context import ParallelContext
-from .parallel_mode import ParallelMode
-from .moe_context import MOE_CONTEXT
-from .process_group_initializer import *
-from .random import *
+
+# from .moe_context import MOE_CONTEXT
+
+__all__ = [
+ 'Config',
+ 'ConfigException',
+]
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
index b41f4072a405..b6e3b52017b2 100644
--- a/colossalai/context/moe_context.py
+++ b/colossalai/context/moe_context.py
@@ -3,13 +3,12 @@
import torch
import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor import ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
def _check_sanity():
- from colossalai.core import global_context as gpc
+ from colossalai.legacy.core import global_context as gpc
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
raise NotImplementedError("Moe is not compatible with tensor or "
"pipeline parallel at present.")
@@ -61,7 +60,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True):
self.world_size = dist.get_world_size()
- from colossalai.core import global_context as gpc
+ from colossalai.legacy.core import global_context as gpc
self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
assert self.world_size % self.max_ep_size == 0, \
"Maximum expert parallel size must be a factor of the number of GPUs"
diff --git a/colossalai/core.py b/colossalai/core.py
deleted file mode 100644
index 153247bbed9c..000000000000
--- a/colossalai/core.py
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from colossalai.context.parallel_context import global_context
-
-__all__ = ['global_context']
\ No newline at end of file
diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index d2bad06bb45a..ccbab0c38a29 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -1,9 +1,11 @@
+import operator
+
import torch
import torch.nn as nn
-import operator
-from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import ShardSpec
-from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
+
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec
+from colossalai.legacy.tensor.distspec import ShardSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
@@ -13,7 +15,7 @@
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
- """weight_split
+ """weight_split
split a nn.Parameter
Args:
@@ -60,9 +62,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
"""
- This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers.
+ This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers.
"""
- #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
+ # TODO: Needs to handle special cases, like x = linear(x) + linear(x)
graph = graph_module.graph
world_size = process_group.world_size()
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index a1694e059fb4..b8718abc80bd 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -1,58 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import argparse
import os
-import pprint
+import warnings
from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Union
import torch
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.optim.optimizer import Optimizer
-from torch.utils.data import DataLoader
+import torch.distributed as dist
-from colossalai.amp import AMP_TYPE, convert_to_amp
-from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.core import global_context as gpc
-from colossalai.legacy.builder.builder import build_gradient_handler
-from colossalai.legacy.engine import Engine
-from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
-from colossalai.legacy.engine.schedule import (
- InterleavedPipelineSchedule,
- NonPipelineSchedule,
- PipelineSchedule,
- get_tensor_shape,
-)
+from colossalai.context import Config
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
-from colossalai.utils.moe import sync_moe_model_param
-from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
-from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
-
-
-def get_default_parser():
- """Reads user command line and uses an argument parser to parse the input arguments.
- Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
-
- Returns:
- Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
- """
- parser = argparse.ArgumentParser()
- parser.add_argument('--config', type=str, help='path to the config file')
- parser.add_argument('--host', type=str, help='the master address for distributed training')
- parser.add_argument('--port', type=int, help='the master port for distributed training')
- parser.add_argument('--world_size', type=int, help='world size for distributed training')
- parser.add_argument('--rank', type=int, help='rank for the default process group')
- parser.add_argument('--local_rank', type=int, help='local rank on the node')
- parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
- return parser
+from colossalai.utils import set_device, set_seed
def launch(config: Union[str, Path, Config, Dict],
@@ -83,40 +42,23 @@ def launch(config: Union[str, Path, Config, Dict],
Raises:
Exception: Raise exception when config type is wrong
"""
- gpc.verbose = verbose
-
- # set config
- assert isinstance(config, (Config, str, Path, dict)), \
- f'expected argument config to be Config, str or Path, but got {type(config)}'
- if not isinstance(config, Config) and isinstance(config, dict):
- config = Config(config)
- if isinstance(config, (str, Path)):
- config = Config.from_file(config)
- gpc.load_config(config)
+ if rank == 0:
+ warnings.warn("`config` is deprecated and will be removed soon.")
# init default process group
- gpc.init_global_dist(rank, world_size, backend, host, port)
-
- # init process groups for different parallel modes from config
- gpc.init_parallel_groups()
+ init_method = f'tcp://[{host}]:{port}'
+ dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
if torch.cuda.is_available():
# if local rank is not given, calculate automatically
- gpc.set_device(local_rank)
-
- # set the number of processes running on the same node
- gpc.detect_num_processes_on_current_node()
+ set_device(local_rank)
- gpc.set_seed(seed)
+ set_seed(seed)
if verbose:
logger = get_dist_logger()
- logger.info(
- f'Distributed environment is initialized, '
- f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
- f'tensor parallel size: {gpc.tensor_parallel_size}',
- ranks=[0])
+ logger.info(f'Distributed environment is initialized, world size: {dist.get_world_size()}', ranks=[0])
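A minimal usage sketch of the slimmed-down launcher, assuming `launch` remains exported at the package root; the host, port, and seed values below are illustrative, not taken from this patch:

    import colossalai

    # `config` is deprecated here and only triggers the rank-0 warning above.
    colossalai.launch(config={},
                      rank=0,
                      world_size=1,
                      host='127.0.0.1',
                      port=29500,
                      backend='gloo',    # or 'nccl' when CUDA is available
                      seed=42)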
def launch_from_slurm(config: Union[str, Path, Config, Dict],
@@ -224,247 +166,3 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
backend=backend,
seed=seed,
verbose=verbose)
-
-
-def initialize(model: nn.Module,
- optimizer: Optimizer,
- criterion: Optional[_Loss] = None,
- train_dataloader: Optional[Iterable] = None,
- test_dataloader: Optional[Iterable] = None,
- lr_scheduler: Optional[_LRScheduler] = None,
- ophooks: Optional[List[BaseOpHook]] = None,
- verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
- """Core function to wrap the essential training components with our functionality based on the config which is
- loaded into gpc.config.
-
- Args:
- model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
- optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
- Your optimizer instance.
- criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
- train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
- test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
- lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
- verbose (bool, optional): Whether to print logs.
-
- Returns:
- Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
- A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
- where only ``engine`` could not be None.
- """
- # get logger
- logger = get_dist_logger()
- gpc.verbose = verbose
-
- # get config from gpc
- config = gpc.config
-
- # print config
- if verbose:
- logger.info(
- f"\n========== Your Config ========\n"
- f"{pprint.pformat(gpc.config)}\n"
- f"================================\n",
- ranks=[0])
-
- # cudnn
- cudnn_benchmark = config.get('cudnn_benchmark', False)
- cudnn_deterministic = config.get('cudnn_deterministic', False)
- torch.backends.cudnn.benchmark = cudnn_benchmark
- torch.backends.cudnn.deterministic = cudnn_deterministic
- if verbose:
- logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
-
- # zero
- use_zero = hasattr(gpc.config, 'zero')
- if use_zero:
- zero_cfg = gpc.config.get('zero', None)
- if zero_cfg is not None:
- cfg_ = zero_cfg.copy()
- else:
- cfg_ = {}
- optimizer_config = zero_cfg.get('optimizer_config', None)
- model_config = zero_cfg.get('model_config', None)
- model, optimizer = convert_to_zero_v2(model,
- optimizer,
- model_config=model_config,
- optimizer_config=optimizer_config)
-
- logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
- else:
- if isinstance(model, nn.Module):
- # first sync model across dp ranks
- model.to(get_current_device())
- elif isinstance(model, Callable):
- model = model().to(get_current_device())
-
- # optimizer maybe a optimizer_cls
- if isinstance(optimizer, Callable):
- optimizer = optimizer(model.parameters())
- logger.warning("Initializing an non ZeRO model with optimizer class")
-
- if not use_zero:
- if is_using_sequence():
- sync_model_param(model, ParallelMode.SEQUENCE_DP)
- elif MOE_CONTEXT.is_initialized:
- sync_moe_model_param(model)
- elif is_using_ddp():
- sync_model_param(model, ParallelMode.DATA)
- else:
- logger.warning(
- "The parameters of models is not automatically synchronized.\n"
- "Please make sure that all parameters are the same in data parallel group.",
- ranks=[0])
-
- # check amp and zero
- fp16_cfg = gpc.config.get('fp16', None)
-
- if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
- raise ConfigException(
- "It is not allowed to set fp16 and zero configuration in your config file at the same time")
-
- # clip grad norm
- clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
-
- # initialize amp
- amp_mode = None
- if fp16_cfg is not None and fp16_cfg.mode is not None:
- cfg_ = fp16_cfg.copy()
- amp_mode = cfg_.pop('mode')
- if is_using_pp():
- assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
- if amp_mode == AMP_TYPE.NAIVE:
- cfg_['clip_grad_norm'] = clip_grad_norm
- model, optimizer, criterion = convert_to_amp(model=model,
- optimizer=optimizer,
- criterion=criterion,
- mode=amp_mode,
- amp_config=cfg_)
-
- # get torch ddp config
- torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
-
- # gradient handler
- gradient_handler_cfg = gpc.config.get('gradient_handler', None)
- if gradient_handler_cfg is None:
- # if gradient handler is not specified in the configuration file,
- # check in the following order
- # 1. if optimizer is ZERO, then use zero grad handler
- # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
- # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
- if isinstance(optimizer, ShardedOptimizerV2):
- gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
- if verbose:
- logger.info(
- "Training with zero is detected, ZeROGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- elif is_using_ddp() and MOE_CONTEXT.is_initialized:
- gradient_handler_cfg = [dict(type='MoeGradientHandler')]
- if verbose:
- logger.info(
- "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- elif is_using_sequence():
- model = DDP(model,
- process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
- device_ids=[torch.cuda.current_device()],
- **torch_ddp_cfg)
- if verbose:
- logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
- ranks=[0])
- elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
- model = DDP(model,
- process_group=gpc.get_group(ParallelMode.DATA),
- device_ids=[torch.cuda.current_device()],
- **torch_ddp_cfg)
- if verbose:
- logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
- elif is_using_ddp():
- gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
- if verbose:
- logger.info(
- "Data parallel training is detected when using pipeline parallel, "
- "DataParallelGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- # add pipeline parallel gradient handler, if pipeline shared module is detected
- for param in model.parameters():
- if getattr(param, 'pipeline_shared_module_pg', None) is not None:
- if gradient_handler_cfg is None:
- gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
- else:
- gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
- if verbose:
- logger.info(
- "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- break
- else:
- if not isinstance(gradient_handler_cfg, list):
- raise ConfigException(
- f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
- )
-
- # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
- # to avoid duplicated buffer synchronization
- if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
- model.module.sync_buffer = False
-
- # initialize schedule for engine
- if is_using_pp():
- tensor_shape = get_tensor_shape()
- use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
- if gpc.is_initialized(ParallelMode.PARALLEL_1D):
- scatter_gather = True
- else:
- scatter_gather = False
- if use_interleaved:
- if isinstance(model, nn.Sequential):
- model = nn.ModuleList([model])
- schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
- gpc.config.model.num_chunks,
- tensor_shape=tensor_shape,
- scatter_gather_tensors=scatter_gather)
- else:
- schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
- tensor_shape=tensor_shape,
- scatter_gather_tensors=scatter_gather)
- else:
- schedule = NonPipelineSchedule()
-
- if gradient_handler_cfg is None:
- gradient_handlers = None
- if verbose and not isinstance(model, DDP):
- logger.warning(
- "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
- "to all-reduce the gradients after a training step.",
- ranks=[0])
- else:
- gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
-
- # check if optimizer is ColossalaiOptimizer
- if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
- optimizer = ColossalaiOptimizer(optim=optimizer)
-
- # gradient accumulation
- grad_accum_size = gpc.config.get('gradient_accumulation', None)
- if grad_accum_size is not None:
- optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
- model=model,
- optimizer=optimizer,
- dataloader=train_dataloader,
- accumulate_size=grad_accum_size,
- gradient_handlers=gradient_handlers,
- lr_scheduler=lr_scheduler)
- engine = Engine(model=model,
- optimizer=optimizer,
- criterion=criterion,
- gradient_handlers=gradient_handlers,
- clip_grad_norm=clip_grad_norm,
- ophook_list=ophooks,
- schedule=schedule)
-
- return engine, train_dataloader, test_dataloader, lr_scheduler
diff --git a/colossalai/legacy/__init__.py b/colossalai/legacy/__init__.py
index e69de29bb2d1..f51941ee800b 100644
--- a/colossalai/legacy/__init__.py
+++ b/colossalai/legacy/__init__.py
@@ -0,0 +1,9 @@
+from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
+
+__all__ = [
+ 'launch',
+ 'launch_from_openmpi',
+ 'launch_from_slurm',
+ 'launch_from_torch',
+ 'initialize',
+]
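With these re-exports, callers can pull the legacy entry points from their new namespace; a one-line sketch:

    # Sketch: import the legacy launchers from their new home.
    from colossalai.legacy import initialize, launch_from_torch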
diff --git a/colossalai/legacy/amp/__init__.py b/colossalai/legacy/amp/__init__.py
new file mode 100644
index 000000000000..e83a7f6ac5cd
--- /dev/null
+++ b/colossalai/legacy/amp/__init__.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
+
+from colossalai.context import Config
+
+from .amp_type import AMP_TYPE
+from .apex_amp import convert_to_apex_amp
+from .naive_amp import convert_to_naive_amp
+from .torch_amp import convert_to_torch_amp
+
+__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
+
+
+def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
+ """A helper function to wrap training components with Torch AMP modules.
+
+ Args:
+ model (:class:`torch.nn.Module`): your model object.
+ optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
+ criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
+ mode (:class:`colossalai.legacy.amp.AMP_TYPE`): amp mode.
+ amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.
+
+ Returns:
+ A tuple (model, optimizer, criterion).
+
+ Note:
+ ``amp_config`` may vary depending on the mode you choose. You should check the corresponding amp mode
+ for more details about ``amp_config``.
+ For ``apex_amp``, please check
+ `apex_amp config `_.
+ For ``naive_amp``, please check
+ `naive_amp config `_.
+ For ``torch_amp``, please check
+ `torch_amp config `_.
+ """
+ assert isinstance(mode, AMP_TYPE), \
+ f'expected the argument mode to be AMP_TYPE, but got {type(mode)}'
+
+ if amp_config is None:
+ amp_config = Config()
+
+ if mode == AMP_TYPE.TORCH:
+ model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
+ elif mode == AMP_TYPE.APEX:
+ model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
+ elif mode == AMP_TYPE.NAIVE:
+ model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
+
+ return model, optimizer, criterion
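A short usage sketch of the dispatch above; the toy model, optimizer, and criterion are placeholders, and ``AMP_TYPE.TORCH`` selects the ``torch.cuda.amp`` path:

    import torch
    import torch.nn as nn

    from colossalai.legacy.amp import AMP_TYPE, convert_to_amp

    model = nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    # Returns the wrapped (model, optimizer, criterion) triple.
    model, optimizer, criterion = convert_to_amp(model, optimizer, criterion, mode=AMP_TYPE.TORCH)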
diff --git a/colossalai/amp/amp_type.py b/colossalai/legacy/amp/amp_type.py
similarity index 100%
rename from colossalai/amp/amp_type.py
rename to colossalai/legacy/amp/amp_type.py
diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/legacy/amp/apex_amp/__init__.py
similarity index 100%
rename from colossalai/amp/apex_amp/__init__.py
rename to colossalai/legacy/amp/apex_amp/__init__.py
diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/legacy/amp/apex_amp/apex_amp.py
similarity index 86%
rename from colossalai/amp/apex_amp/apex_amp.py
rename to colossalai/legacy/amp/apex_amp/apex_amp.py
index e6bdbe4520f9..acc051181562 100644
--- a/colossalai/amp/apex_amp/apex_amp.py
+++ b/colossalai/legacy/amp/apex_amp/apex_amp.py
@@ -10,11 +10,11 @@
from torch import Tensor
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils import clip_grad_norm_fp32
+from colossalai.interface import OptimizerWrapper
+from colossalai.legacy.utils import clip_grad_norm_fp32
-class ApexAMPOptimizer(ColossalaiOptimizer):
+class ApexAMPOptimizer(OptimizerWrapper):
""" A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm
methods
"""
diff --git a/colossalai/legacy/amp/naive_amp/__init__.py b/colossalai/legacy/amp/naive_amp/__init__.py
new file mode 100644
index 000000000000..2ee84fc763b1
--- /dev/null
+++ b/colossalai/legacy/amp/naive_amp/__init__.py
@@ -0,0 +1,60 @@
+import inspect
+
+import torch.nn as nn
+from torch.optim import Optimizer
+
+from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler
+from colossalai.legacy.utils import is_no_pp_or_last_stage
+
+from ._fp16_optimizer import FP16Optimizer
+from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
+
+
+def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
+ """A helper function to wrap training components with naive AMP modules. In this mode,
+ we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
+ which is equivalent to Apex O3.
+
+ Args:
+ model (:class:`torch.nn.Module`): your model object
+ optimizer (:class:`torch.optim.Optimizer`): your optimizer object
+ amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
+
+ Returns:
+ Tuple: A tuple (model, optimizer)
+
+ The ``amp_config`` should contain parameters below::
+
+ verbose (bool, optional): if set to `True`, will print debug info (Default: False).
+ clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
+ Note that clipping is ignored if clip_grad_norm == 0.
+ dynamic_grad_scale (bool): whether to use dynamic grad scaler.
+ """
+ if isinstance(model, nn.ModuleList):
+ # interleaved pipeline
+ module_list = []
+ for chunk, m in enumerate(model):
+ output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
+ module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
+ model = nn.ModuleList(module_list)
+ else:
+ output_to_fp32 = is_no_pp_or_last_stage()
+ model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
+
+ use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
+ if use_dynamic_grad_scaler:
+ scaler_class = DynamicGradScaler
+ else:
+ scaler_class = ConstantGradScaler
+
+ sig = inspect.signature(scaler_class.__init__)
+ kwargs = dict()
+ for param in sig.parameters.values():
+ if param.name in amp_config:
+ kwargs[param.name] = amp_config.pop(param.name)
+ grad_scaler = scaler_class(**kwargs)
+ optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
+ return model, optimizer
+
+
+__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
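Because of the signature introspection above, ``amp_config`` can mix scaler and optimizer settings: keys matching the chosen scaler's ``__init__`` are popped for the grad scaler, and the remainder is forwarded to ``NaiveAMPOptimizer``. An illustrative config follows (the key names other than ``dynamic_grad_scale`` are assumptions, not taken from this patch):

    # Hypothetical amp_config; only keys the scaler accepts are consumed by it.
    amp_config = dict(
        dynamic_grad_scale=True,    # selects DynamicGradScaler
        initial_scale=2**16,        # popped if DynamicGradScaler.__init__ accepts it
        clip_grad_norm=1.0,         # left over, forwarded to NaiveAMPOptimizer
    )
    model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)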
diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py
similarity index 97%
rename from colossalai/amp/naive_amp/_fp16_optimizer.py
rename to colossalai/legacy/amp/naive_amp/_fp16_optimizer.py
index e4699f92b944..2733477599f7 100644
--- a/colossalai/amp/naive_amp/_fp16_optimizer.py
+++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py
@@ -6,14 +6,15 @@
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
from colossalai.logging import get_dist_logger
-from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier
+from colossalai.utils import multi_tensor_applier
from ._utils import has_inf_or_nan, zero_gard_by_list
-from .grad_scaler import BaseGradScaler
try:
from colossalai._C import fused_optim
diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/legacy/amp/naive_amp/_utils.py
similarity index 100%
rename from colossalai/amp/naive_amp/_utils.py
rename to colossalai/legacy/amp/naive_amp/_utils.py
diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/legacy/amp/naive_amp/naive_amp.py
similarity index 94%
rename from colossalai/amp/naive_amp/naive_amp.py
rename to colossalai/legacy/amp/naive_amp/naive_amp.py
index 6a39d518d3f4..1fab3e5a0d0d 100644
--- a/colossalai/amp/naive_amp/naive_amp.py
+++ b/colossalai/legacy/amp/naive_amp/naive_amp.py
@@ -11,14 +11,14 @@
from torch.distributed import ReduceOp
from torch.optim import Optimizer
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.interface import OptimizerWrapper
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from ._fp16_optimizer import FP16Optimizer
-class NaiveAMPOptimizer(ColossalaiOptimizer):
+class NaiveAMPOptimizer(OptimizerWrapper):
"""A wrapper class for optimizer to cast all parameters to fp16
Args:
@@ -57,7 +57,7 @@ class NaiveAMPModel(nn.Module):
Args:
model (torch.nn.Module): torch.nn.Module to be wrapped.
output_to_fp32 (bool, optional): Whether to cast the output of this module to fp32. (Default: True)
- parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this module.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this module.
(Default: ``ParallelMode.DATA``)
sync_buffer (bool, optional): whether to synchronize buffer. (Default: True)
diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/legacy/amp/torch_amp/__init__.py
similarity index 100%
rename from colossalai/amp/torch_amp/__init__.py
rename to colossalai/legacy/amp/torch_amp/__init__.py
diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/legacy/amp/torch_amp/_grad_scaler.py
similarity index 99%
rename from colossalai/amp/torch_amp/_grad_scaler.py
rename to colossalai/legacy/amp/torch_amp/_grad_scaler.py
index ed4b8e484436..543dac6ab5ef 100644
--- a/colossalai/amp/torch_amp/_grad_scaler.py
+++ b/colossalai/legacy/amp/torch_amp/_grad_scaler.py
@@ -13,8 +13,8 @@
from packaging import version
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
class _MultiDeviceReplicator(object):
diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py
similarity index 95%
rename from colossalai/amp/torch_amp/torch_amp.py
rename to colossalai/legacy/amp/torch_amp/torch_amp.py
index 65718d77c2e0..c45a5956a205 100644
--- a/colossalai/amp/torch_amp/torch_amp.py
+++ b/colossalai/legacy/amp/torch_amp/torch_amp.py
@@ -7,13 +7,13 @@
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils import clip_grad_norm_fp32
+from colossalai.interface import OptimizerWrapper
+from colossalai.legacy.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler
-class TorchAMPOptimizer(ColossalaiOptimizer):
+class TorchAMPOptimizer(OptimizerWrapper):
"""A wrapper class which integrate Pytorch AMP with an optimizer
Args:
diff --git a/colossalai/legacy/communication/collective.py b/colossalai/legacy/communication/collective.py
index 64fb5b8b5296..7471188226f0 100644
--- a/colossalai/legacy/communication/collective.py
+++ b/colossalai/legacy/communication/collective.py
@@ -6,8 +6,8 @@
from torch import Tensor
from torch.distributed import ReduceOp
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
_all_gather_func = dist._all_gather_base \
if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
@@ -26,7 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
Args:
tensor (:class:`torch.Tensor`): Tensor to be gathered.
dim (int): The dimension to concatenate along.
- parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
async_op (bool, optional): Whether operations are asynchronous.
Returns:
@@ -65,7 +65,7 @@ def reduce_scatter(tensor: Tensor,
Args:
tensor (:class:`torch.Tensor`): Tensor to be reduce_scattered.
dim (int): The dimension to concatenate along.
- parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
op (torch.distributed.ReduceOp, optional): The type of reduce operation,
should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
For more details about ReduceOp, please refer to
@@ -105,7 +105,7 @@ def all_reduce(tensor: Tensor,
Args:
tensor (:class:`torch.Tensor`): Tensor to be all-reduced.
- parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
op (torch.distributed.ReduceOp, optional): The type of reduce operation,
should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
For more details about ReduceOp, please refer to
@@ -141,7 +141,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
Args:
tensor (:class:`torch.Tensor`): Tensor to be broadcast.
src (int): Source rank.
- parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
async_op (bool, optional): Whether operations are asynchronous.
Returns:
@@ -173,7 +173,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
Args:
tensor (:class:`torch.Tensor`): Tensor to be reduced.
dst (int): Destination rank.
- parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
async_op (bool, optional): Whether operations are asynchronous.
Returns:
diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py
index d28d140168fd..e3f9108ab840 100644
--- a/colossalai/legacy/communication/p2p.py
+++ b/colossalai/legacy/communication/p2p.py
@@ -8,8 +8,8 @@
import torch
import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
diff --git a/colossalai/legacy/communication/p2p_v2.py b/colossalai/legacy/communication/p2p_v2.py
index 090311cb35f2..66af214950f2 100644
--- a/colossalai/legacy/communication/p2p_v2.py
+++ b/colossalai/legacy/communication/p2p_v2.py
@@ -10,8 +10,8 @@
from torch.distributed import ProcessGroupNCCL
from torch.distributed import distributed_c10d as c10d
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
TensorShape = Union[torch.Size, List[int], Tuple[int]]
_pg_manager = {}
diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py
index aece7574b7c4..e80192fb578d 100644
--- a/colossalai/legacy/communication/ring.py
+++ b/colossalai/legacy/communication/ring.py
@@ -3,8 +3,8 @@
import torch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize
diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py
index 1516df356278..7e3dcf1e9820 100644
--- a/colossalai/legacy/communication/utils.py
+++ b/colossalai/legacy/communication/utils.py
@@ -3,8 +3,8 @@
import torch
import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device
TensorShape = Union[torch.Size, List[int], Tuple[int]]
diff --git a/colossalai/constants.py b/colossalai/legacy/constants.py
similarity index 100%
rename from colossalai/constants.py
rename to colossalai/legacy/constants.py
diff --git a/colossalai/legacy/context/__init__.py b/colossalai/legacy/context/__init__.py
new file mode 100644
index 000000000000..7027945ead7c
--- /dev/null
+++ b/colossalai/legacy/context/__init__.py
@@ -0,0 +1,4 @@
+from .parallel_context import ParallelContext
+from .parallel_mode import ParallelMode
+from .process_group_initializer import *
+from .random import *
diff --git a/colossalai/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py
similarity index 88%
rename from colossalai/context/parallel_context.py
rename to colossalai/legacy/context/parallel_context.py
index 7186f052ecec..8fdc3d6fea68 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/legacy/context/parallel_context.py
@@ -11,10 +11,10 @@
import torch
import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.constants import ALLOWED_MODES, INITIALIZER_MAPPING
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from colossalai.logging import get_dist_logger
@@ -110,12 +110,12 @@ def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the global rank of the current device for `parallel_mode` to the context.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._global_ranks[parallel_mode] = rank
@@ -124,11 +124,11 @@ def get_local_rank(self, parallel_mode: ParallelMode):
"""Returns the local rank of the current device.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The local rank of the current device for `parallel_mode`.
@@ -140,12 +140,12 @@ def _add_local_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the local rank of the current device for `parallel_mode` to the context.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank
@@ -154,11 +154,11 @@ def get_next_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the next device.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The global rank of the next device for `parallel_mode`.
@@ -176,11 +176,11 @@ def get_prev_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the previous device.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The global rank of the previous device for `parallel_mode`.
@@ -199,11 +199,11 @@ def is_first_rank(self, parallel_mode: ParallelMode):
among its group for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
bool: a boolean value indicating whether the current device is the first one
@@ -217,11 +217,11 @@ def is_last_rank(self, parallel_mode: ParallelMode):
among its group for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
bool: a boolean value indicating whether the current device is the last one
@@ -248,11 +248,11 @@ def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The world size for `parallel_mode`.
@@ -264,12 +264,12 @@ def _add_world_size(self, parallel_mode: ParallelMode, world_size: int):
"""Adds world size for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode corresponding to the process group
world_size (int): The world size to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._world_sizes[parallel_mode] = world_size
@@ -278,11 +278,11 @@ def get_group(self, parallel_mode: ParallelMode):
"""Returns the group of the current device for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
@@ -294,12 +294,12 @@ def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the group of the current device for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
group (torch.distributed.ProcessGroup): The group to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group
@@ -308,9 +308,9 @@ def get_cpu_group(self, parallel_mode: ParallelMode):
"""Returns the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
- :type parallel_mode: :class:`colossalai.context.ParallelMode`
+ :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`
+ of :class:`colossalai.legacy.context.ParallelMode`
:return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup
"""
@@ -321,11 +321,11 @@ def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
- :type parallel_mode: :class:`colossalai.context.ParallelMode`
+ :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode`
:param group: The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`
+ of :class:`colossalai.legacy.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._cpu_groups[parallel_mode] = group
@@ -334,11 +334,11 @@ def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
Returns:
int: The rank of the current device for `parallel_mode` in the group.
@@ -350,12 +350,12 @@ def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
"""Adds the ranks of the current device for `parallel_mode` in the group.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
ranks (list): List of ranks to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
- of :class:`colossalai.context.ParallelMode`.
+ of :class:`colossalai.legacy.context.ParallelMode`.
"""
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
@@ -489,7 +489,7 @@ def is_initialized(self, parallel_mode: ParallelMode):
in the current system.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Returns:
bool: a boolean value indicating whether `parallel_mode` is initialized in the current system.
diff --git a/colossalai/context/parallel_mode.py b/colossalai/legacy/context/parallel_mode.py
similarity index 100%
rename from colossalai/context/parallel_mode.py
rename to colossalai/legacy/context/parallel_mode.py
diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/legacy/context/process_group_initializer/__init__.py
similarity index 100%
rename from colossalai/context/process_group_initializer/__init__.py
rename to colossalai/legacy/context/process_group_initializer/__init__.py
index d3937a947437..48d52d7b9e52 100644
--- a/colossalai/context/process_group_initializer/__init__.py
+++ b/colossalai/legacy/context/process_group_initializer/__init__.py
@@ -3,10 +3,10 @@
from .initializer_2p5d import Initializer_2p5D
from .initializer_3d import Initializer_3D
from .initializer_data import Initializer_Data
+from .initializer_model import Initializer_Model
from .initializer_pipeline import Initializer_Pipeline
from .initializer_sequence import Initializer_Sequence
from .initializer_tensor import Initializer_Tensor
-from .initializer_model import Initializer_Model
from .process_group_initializer import ProcessGroupInitializer
__all__ = [
diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/legacy/context/process_group_initializer/initializer_1d.py
similarity index 96%
rename from colossalai/context/process_group_initializer/initializer_1d.py
rename to colossalai/legacy/context/process_group_initializer/initializer_1d.py
index ba601d0bf61a..d853c6f06fc0 100644
--- a/colossalai/context/process_group_initializer/initializer_1d.py
+++ b/colossalai/legacy/context/process_group_initializer/initializer_1d.py
@@ -3,7 +3,7 @@
import torch.distributed as dist
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py
similarity index 98%
rename from colossalai/context/process_group_initializer/initializer_2d.py
rename to colossalai/legacy/context/process_group_initializer/initializer_2d.py
index 999cd5f0cfc6..39f6a46890b6 100644
--- a/colossalai/context/process_group_initializer/initializer_2d.py
+++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py
@@ -2,7 +2,7 @@
import torch.distributed as dist
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py
similarity index 99%
rename from colossalai/context/process_group_initializer/initializer_2p5d.py
rename to colossalai/legacy/context/process_group_initializer/initializer_2p5d.py
index b92ae2eec07e..bb7a3509572f 100644
--- a/colossalai/context/process_group_initializer/initializer_2p5d.py
+++ b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py
@@ -6,7 +6,7 @@
import torch.distributed as dist
from colossalai.context import Config
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/legacy/context/process_group_initializer/initializer_3d.py
similarity index 99%
rename from colossalai/context/process_group_initializer/initializer_3d.py
rename to colossalai/legacy/context/process_group_initializer/initializer_3d.py
index 6bca05ad7d5f..3dfbf5223b12 100644
--- a/colossalai/context/process_group_initializer/initializer_3d.py
+++ b/colossalai/legacy/context/process_group_initializer/initializer_3d.py
@@ -5,7 +5,7 @@
import torch.distributed as dist
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/legacy/context/process_group_initializer/initializer_data.py
similarity index 100%
rename from colossalai/context/process_group_initializer/initializer_data.py
rename to colossalai/legacy/context/process_group_initializer/initializer_data.py
diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/legacy/context/process_group_initializer/initializer_model.py
similarity index 100%
rename from colossalai/context/process_group_initializer/initializer_model.py
rename to colossalai/legacy/context/process_group_initializer/initializer_model.py
diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py
similarity index 100%
rename from colossalai/context/process_group_initializer/initializer_pipeline.py
rename to colossalai/legacy/context/process_group_initializer/initializer_pipeline.py
diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py
similarity index 100%
rename from colossalai/context/process_group_initializer/initializer_sequence.py
rename to colossalai/legacy/context/process_group_initializer/initializer_sequence.py
diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py
similarity index 100%
rename from colossalai/context/process_group_initializer/initializer_tensor.py
rename to colossalai/legacy/context/process_group_initializer/initializer_tensor.py
diff --git a/colossalai/context/process_group_initializer/process_group_initializer.py b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py
similarity index 100%
rename from colossalai/context/process_group_initializer/process_group_initializer.py
rename to colossalai/legacy/context/process_group_initializer/process_group_initializer.py
diff --git a/colossalai/context/random/__init__.py b/colossalai/legacy/context/random/__init__.py
similarity index 100%
rename from colossalai/context/random/__init__.py
rename to colossalai/legacy/context/random/__init__.py
diff --git a/colossalai/context/random/_helper.py b/colossalai/legacy/context/random/_helper.py
similarity index 90%
rename from colossalai/context/random/_helper.py
rename to colossalai/legacy/context/random/_helper.py
index 973c4d9faa32..4b5d5ef2fe55 100644
--- a/colossalai/context/random/_helper.py
+++ b/colossalai/legacy/context/random/_helper.py
@@ -7,8 +7,8 @@
import torch.cuda
from torch import Tensor
-from .seed_manager import SeedManager
from ..parallel_mode import ParallelMode
+from .seed_manager import SeedManager
_SEED_MANAGER = SeedManager()
@@ -53,11 +53,11 @@ def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added
Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
- :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added.
+ :class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -70,7 +70,7 @@ def set_mode(parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -83,7 +83,7 @@ def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
Raises:
@@ -161,7 +161,7 @@ def wrapper(*args, **kwargs):
def moe_set_seed(seed):
if torch.cuda.is_available():
- from colossalai.core import global_context as gpc
+ from colossalai.legacy.core import global_context as gpc
global_rank = gpc.get_global_rank()
diff_seed = seed + global_rank
add_seed(ParallelMode.TENSOR, diff_seed, True)
diff --git a/colossalai/context/random/seed_manager.py b/colossalai/legacy/context/random/seed_manager.py
similarity index 86%
rename from colossalai/context/random/seed_manager.py
rename to colossalai/legacy/context/random/seed_manager.py
index 956f9001200d..b657ff7e1d32 100644
--- a/colossalai/context/random/seed_manager.py
+++ b/colossalai/legacy/context/random/seed_manager.py
@@ -4,7 +4,7 @@
import torch
from torch import Tensor
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.legacy.context.parallel_mode import ParallelMode
class SeedManager:
@@ -36,7 +36,7 @@ def set_state(self, parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
state (:class:`torch.Tensor`): the state to be set.
Raises:
@@ -49,7 +49,7 @@ def set_mode(self, parallel_mode: ParallelMode):
"""Sets the current mode of the seed manager.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
"""
if self.current_mode:
# save the current state for current mode
@@ -63,12 +63,12 @@ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = Fal
"""Adds a seed to the seed manager for `parallel_mode`.
Args:
- parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
seed (int): The seed to be added.
overwrite (bool, optional): Whether to allow overwriting a seed that has already been set
Raises:
- AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.context.ParallelMode`
+ AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode`
or the seed for `parallel_mode` has been added.
"""
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
diff --git a/colossalai/legacy/core.py b/colossalai/legacy/core.py
new file mode 100644
index 000000000000..0aaf1ee47730
--- /dev/null
+++ b/colossalai/legacy/core.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from colossalai.legacy.context.parallel_context import global_context
+
+__all__ = ['global_context']
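Call sites throughout this patch reach the global context through this shim; a one-line sketch:

    # Sketch: the legacy singleton is now imported from its new module path.
    from colossalai.legacy.core import global_context as gpc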
diff --git a/colossalai/legacy/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py
index 9af4469f403f..930caf20c1dd 100644
--- a/colossalai/legacy/engine/_base_engine.py
+++ b/colossalai/legacy/engine/_base_engine.py
@@ -8,6 +8,7 @@
from torch.nn import Module
from torch.nn.modules.loss import _Loss
+from colossalai.interface import OptimizerWrapper
from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.schedule import (
BaseSchedule,
@@ -15,9 +16,8 @@
NonPipelineSchedule,
PipelineSchedule,
)
+from colossalai.legacy.zero.gemini import BaseOpHook, register_ophooks_recursively
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
class Engine:
@@ -27,7 +27,7 @@ class Engine:
Args:
model (``torch.nn.Module``): The neural network model.
- optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters.
+ optimizer (``colossalai.interface.OptimizerWrapper``): Optimizer for updating the parameters.
criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward.
clip_grad_norm (float, optional): The norm of gradient clipping.
@@ -61,7 +61,7 @@ class Engine:
def __init__(self,
model: Module,
- optimizer: "ColossalaiOptimizer",
+ optimizer: "OptimizerWrapper",
criterion: Optional[_Loss] = None,
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
clip_grad_norm: float = 0.0,
@@ -157,7 +157,7 @@ def step(self):
"""Execute parameter update
"""
self._all_reduce_gradients()
- self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
+ self.optimizer.clip_grad_by_norm(self._clip_grad_norm)
return self.optimizer.step()
def backward(self, loss: Tensor):
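Note the interface change in ``step``: ``OptimizerWrapper.clip_grad_by_norm`` takes only the max norm, whereas the old ``ColossalaiOptimizer.clip_grad_norm`` also required the model. A sketch under that assumption (constructor and method signature inferred from the call sites changed here):

    import torch
    import torch.nn as nn

    from colossalai.interface import OptimizerWrapper

    model = nn.Linear(4, 4)
    optim = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr=1e-3))
    # Clip by global L2 norm; no model reference is needed any more.
    optim.clip_grad_by_norm(max_norm=1.0)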
diff --git a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
index c466f7e2d03b..c2270dc53a50 100644
--- a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
+++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
@@ -10,12 +10,12 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
+from colossalai.interface import OptimizerWrapper
from colossalai.legacy.engine import BaseGradientHandler
-from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import conditional_context
-class GradAccumOptimizer(ColossalaiOptimizer):
+class GradAccumOptimizer(OptimizerWrapper):
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps
before accumulation size is reached.
@@ -74,7 +74,7 @@ def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None:
if self.accumulate_step < self.accumulate_size:
pass
else:
- self.optim.clip_grad_norm(model, max_norm)
+ self.optim.clip_grad_by_norm(max_norm)
def backward(self, loss: Tensor) -> None:
"""Execute backward pass.
diff --git a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
index c5da2e55a0ed..c692ee903442 100644
--- a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
@@ -1,5 +1,5 @@
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
index 395d83da0478..e7a6df2d8ae8 100644
--- a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
@@ -1,6 +1,6 @@
from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict
diff --git a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
index 7d4d9d73afc8..3eae7d58ac95 100644
--- a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
@@ -6,7 +6,7 @@
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
index 41098ab39d0c..38b7f5993b73 100644
--- a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
@@ -1,5 +1,5 @@
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
index 4571fd679e8c..37eed82f8a28 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -7,11 +7,11 @@
import torch.cuda
import colossalai.legacy.communication as comm
-from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.amp.naive_amp import NaiveAMPModel
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from ._base_schedule import BaseSchedule
@@ -157,7 +157,7 @@ def load_micro_batch(self):
return self._move_to_device(micro_batch_data)
def pre_processing(self, engine):
- from colossalai.zero.legacy import ShardedModelV2
+ from colossalai.legacy.zero import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
index 385c615372f5..bf8b599a81ae 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -6,8 +6,8 @@
import torch.cuda
import colossalai.legacy.communication.p2p_v2 as comm
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
from colossalai.utils.cuda import get_current_device
diff --git a/colossalai/global_variables.py b/colossalai/legacy/global_variables.py
similarity index 100%
rename from colossalai/global_variables.py
rename to colossalai/legacy/global_variables.py
diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py
new file mode 100644
index 000000000000..2c253adbaf38
--- /dev/null
+++ b/colossalai/legacy/initialize.py
@@ -0,0 +1,472 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import argparse
+import os
+import pprint
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
+from torch.utils.data import DataLoader
+
+from colossalai.context import Config, ConfigException
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.interface import OptimizerWrapper
+from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
+from colossalai.legacy.amp.naive_amp import NaiveAMPModel
+from colossalai.legacy.builder.builder import build_gradient_handler
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
+ InterleavedPipelineSchedule,
+ NonPipelineSchedule,
+ PipelineSchedule,
+ get_tensor_shape,
+)
+from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
+from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
+from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
+from colossalai.logging import get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.utils.moe import sync_moe_model_param
+
+
+def get_default_parser():
+ """Reads user command line and uses an argument parser to parse the input arguments.
+ Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
+
+ Returns:
+ Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, help='path to the config file')
+ parser.add_argument('--host', type=str, help='the master address for distributed training')
+ parser.add_argument('--port', type=int, help='the master port for distributed training')
+ parser.add_argument('--world_size', type=int, help='world size for distributed training')
+ parser.add_argument('--rank', type=int, help='rank for the default process group')
+ parser.add_argument('--local_rank', type=int, help='local rank on the node')
+ parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
+ return parser
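+
+
+# A minimal usage sketch (illustrative only; the call below is an assumption for
+# the example, not part of this patch): the parsed defaults map directly onto
+# the arguments of `launch`.
+#
+#   args = get_default_parser().parse_args()
+#   launch(config=args.config, rank=args.rank, world_size=args.world_size,
+#          host=args.host, port=args.port, backend=args.backend,
+#          local_rank=args.local_rank)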
+
+
+def launch(config: Union[str, Path, Config, Dict],
+ rank: int,
+ world_size: int,
+ host: str,
+ port: int,
+ backend: str = 'nccl',
+           local_rank: Optional[int] = None,
+ seed: int = 1024,
+ verbose: bool = True):
+ """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
+ arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
+
+ Args:
+        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
+ rank (int): Rank for the default process group
+ world_size (int): World size of the default process group
+ host (str): The master address for distributed training
+        port (int): The master port for distributed training
+ backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+ local_rank (int, optional):
+            Rank of the process on the node, used to set the default CUDA device,
+            defaults to None. If local_rank is None, the default device ordinal will be calculated automatically.
+ seed (int, optional): Specified random seed for every process. Defaults to 1024.
+ verbose (bool, optional): Whether to print logs. Defaults to True.
+
+ Raises:
+        Exception: Raised when the config type is wrong.
+ """
+ gpc.verbose = verbose
+
+ # set config
+ assert isinstance(config, (Config, str, Path, dict)), \
+        f'expected argument config to be Config, str, Path or dict, but got {type(config)}'
+ if not isinstance(config, Config) and isinstance(config, dict):
+ config = Config(config)
+ if isinstance(config, (str, Path)):
+ config = Config.from_file(config)
+ gpc.load_config(config)
+
+ # init default process group
+ gpc.init_global_dist(rank, world_size, backend, host, port)
+
+ # init process groups for different parallel modes from config
+ gpc.init_parallel_groups()
+
+ # set cuda device
+ if torch.cuda.is_available():
+ # if local rank is not given, calculate automatically
+ gpc.set_device(local_rank)
+
+ # set the number of processes running on the same node
+ gpc.detect_num_processes_on_current_node()
+
+ gpc.set_seed(seed)
+
+ if verbose:
+ logger = get_dist_logger()
+ logger.info(
+ f'Distributed environment is initialized, '
+ f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
+ f'tensor parallel size: {gpc.tensor_parallel_size}',
+ ranks=[0])
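+
+
+# An illustrative sketch (hypothetical config path and rendezvous address) of
+# calling `launch` directly when you manage the processes yourself; every
+# process must call it with its own rank:
+#
+#   launch(config='./config.py', rank=rank, world_size=8,
+#          host='node0', port=29500, backend='nccl')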
+
+
+def launch_from_slurm(config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = 'nccl',
+ seed: int = 1024,
+ verbose: bool = True):
+ """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
+ set by SLURM
+
+ Args:
+        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
+ host (str): The master address for distributed training
+        port (int): The master port for distributed training
+ backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+ seed (int, optional): Specified random seed for every process. Defaults to 1024.
+ verbose (bool, optional): Whether to print logs. Defaults to True.
+ """
+ try:
+ rank = int(os.environ['SLURM_PROCID'])
+ world_size = int(os.environ['SLURM_NPROCS'])
+ except KeyError as e:
+ raise RuntimeError(
+ f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
+ )
+
+ launch(config=config,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose)
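+
+
+# SLURM sets SLURM_PROCID and SLURM_NPROCS per task, so a hypothetical job
+# script only needs to provide the rendezvous address, e.g.:
+#
+#   srun python train.py            # and inside train.py:
+#   launch_from_slurm(config='./config.py', host='node0', port=29500)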
+
+
+def launch_from_openmpi(config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = 'nccl',
+ seed: int = 1024,
+ verbose: bool = True):
+ """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
+ set by OpenMPI
+
+ Args:
+        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
+ host (str): The master address for distributed training
+        port (int): The master port for distributed training
+ backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+ seed (int, optional): Specified random seed for every process. Defaults to 1024.
+ verbose (bool, optional): Whether to print logs. Defaults to True.
+ """
+ try:
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ except KeyError as e:
+ raise RuntimeError(
+ f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
+ )
+
+ launch(config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose)
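+
+
+# OpenMPI sets OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_LOCAL_RANK and
+# OMPI_COMM_WORLD_SIZE per process; a hypothetical invocation is:
+#
+#   mpirun -np 8 python train.py    # and inside train.py:
+#   launch_from_openmpi(config='./config.py', host='node0', port=29500)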
+
+
+def launch_from_torch(config: Union[str, Path, Config, Dict],
+ backend: str = 'nccl',
+ seed: int = 1024,
+ verbose: bool = True):
+ """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
+ from the environment variables set by PyTorch
+
+ Args:
+        config (Union[str, dict, Config]): A config dict/object or the path to a config file.
+ backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+ seed (int, optional): Specified random seed for every process. Defaults to 1024.
+ verbose (bool, optional): Whether to print logs. Defaults to True.
+ """
+ try:
+ rank = int(os.environ['RANK'])
+ local_rank = int(os.environ['LOCAL_RANK'])
+ world_size = int(os.environ['WORLD_SIZE'])
+ host = os.environ['MASTER_ADDR']
+ port = int(os.environ['MASTER_PORT'])
+ except KeyError as e:
+ raise RuntimeError(
+ f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
+ )
+
+ launch(config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose)
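+
+
+# torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
+# so a hypothetical invocation reduces to:
+#
+#   torchrun --nproc_per_node 8 train.py    # and inside train.py:
+#   launch_from_torch(config='./config.py')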
+
+
+def initialize(model: nn.Module,
+ optimizer: Optimizer,
+ criterion: Optional[_Loss] = None,
+ train_dataloader: Optional[Iterable] = None,
+ test_dataloader: Optional[Iterable] = None,
+ lr_scheduler: Optional[_LRScheduler] = None,
+ ophooks: Optional[List[BaseOpHook]] = None,
+ verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+ """Core function to wrap the essential training components with our functionality based on the config which is
+ loaded into gpc.config.
+
+ Args:
+ model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
+ optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
+ Your optimizer instance.
+ criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
+ train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
+ test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
+        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance.
+        ophooks (list, optional): A list of op hooks passed to the engine.
+        verbose (bool, optional): Whether to print logs.
+
+ Returns:
+ Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
+ A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
+            where only ``engine`` is guaranteed not to be None.
+ """
+ # get logger
+ logger = get_dist_logger()
+ gpc.verbose = verbose
+
+ # get config from gpc
+ config = gpc.config
+
+ # print config
+ if verbose:
+ logger.info(
+ f"\n========== Your Config ========\n"
+ f"{pprint.pformat(gpc.config)}\n"
+ f"================================\n",
+ ranks=[0])
+
+ # cudnn
+ cudnn_benchmark = config.get('cudnn_benchmark', False)
+ cudnn_deterministic = config.get('cudnn_deterministic', False)
+ torch.backends.cudnn.benchmark = cudnn_benchmark
+ torch.backends.cudnn.deterministic = cudnn_deterministic
+ if verbose:
+ logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
+
+ # zero
+ use_zero = hasattr(gpc.config, 'zero')
+ if use_zero:
+        zero_cfg = gpc.config.get('zero', None)
+        # `use_zero` guarantees that the `zero` attribute exists in the config
+        optimizer_config = zero_cfg.get('optimizer_config', None)
+ model_config = zero_cfg.get('model_config', None)
+ model, optimizer = convert_to_zero_v2(model,
+ optimizer,
+ model_config=model_config,
+ optimizer_config=optimizer_config)
+
+ logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
+ else:
+ if isinstance(model, nn.Module):
+        # move the model to the current device
+ model.to(get_current_device())
+ elif isinstance(model, Callable):
+ model = model().to(get_current_device())
+
+    # the optimizer may be an optimizer class rather than an instance
+ if isinstance(optimizer, Callable):
+ optimizer = optimizer(model.parameters())
+ logger.warning("Initializing an non ZeRO model with optimizer class")
+
+ if not use_zero:
+ if is_using_sequence():
+ sync_model_param(model, ParallelMode.SEQUENCE_DP)
+ elif MOE_CONTEXT.is_initialized:
+ sync_moe_model_param(model)
+ elif is_using_ddp():
+ sync_model_param(model, ParallelMode.DATA)
+ else:
+ logger.warning(
+ "The parameters of models is not automatically synchronized.\n"
+ "Please make sure that all parameters are the same in data parallel group.",
+ ranks=[0])
+
+ # check amp and zero
+ fp16_cfg = gpc.config.get('fp16', None)
+
+ if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
+ raise ConfigException(
+ "It is not allowed to set fp16 and zero configuration in your config file at the same time")
+
+ # clip grad norm
+ clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
+
+ # initialize amp
+ amp_mode = None
+ if fp16_cfg is not None and fp16_cfg.mode is not None:
+ cfg_ = fp16_cfg.copy()
+ amp_mode = cfg_.pop('mode')
+ if is_using_pp():
+            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline parallelism only supports NaiveAMP currently'
+ if amp_mode == AMP_TYPE.NAIVE:
+ cfg_['clip_grad_norm'] = clip_grad_norm
+ model, optimizer, criterion = convert_to_amp(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ mode=amp_mode,
+ amp_config=cfg_)
+
+ # get torch ddp config
+ torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
+
+ # gradient handler
+ gradient_handler_cfg = gpc.config.get('gradient_handler', None)
+ if gradient_handler_cfg is None:
+        # if the gradient handler is not specified in the configuration file,
+        # check in the following order:
+        # 1. if the optimizer is ZeRO, use the ZeRO gradient handler
+        # 2. if the dp size is larger than 1 and pipeline is not used, use PyTorch DDP
+        # 3. if pipeline is used and the dp size is larger than 1, use the data parallel gradient handler
+ if isinstance(optimizer, ShardedOptimizerV2):
+ gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
+ if verbose:
+ logger.info(
+ "Training with zero is detected, ZeROGradientHandler is automatically "
+ "added even though not specified in the configuration",
+ ranks=[0])
+ elif is_using_ddp() and MOE_CONTEXT.is_initialized:
+ gradient_handler_cfg = [dict(type='MoeGradientHandler')]
+ if verbose:
+ logger.info(
+ "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
+ "added even though not specified in the configuration",
+ ranks=[0])
+ elif is_using_sequence():
+ model = DDP(model,
+ process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
+ device_ids=[torch.cuda.current_device()],
+ **torch_ddp_cfg)
+ if verbose:
+ logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
+ ranks=[0])
+ elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
+ model = DDP(model,
+ process_group=gpc.get_group(ParallelMode.DATA),
+ device_ids=[torch.cuda.current_device()],
+ **torch_ddp_cfg)
+ if verbose:
+ logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
+ elif is_using_ddp():
+ gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
+ if verbose:
+ logger.info(
+ "Data parallel training is detected when using pipeline parallel, "
+ "DataParallelGradientHandler is automatically "
+ "added even though not specified in the configuration",
+ ranks=[0])
+ # add pipeline parallel gradient handler, if pipeline shared module is detected
+ for param in model.parameters():
+ if getattr(param, 'pipeline_shared_module_pg', None) is not None:
+ if gradient_handler_cfg is None:
+ gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
+ else:
+ gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
+ if verbose:
+ logger.info(
+ "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
+ "added even though not specified in the configuration",
+ ranks=[0])
+ break
+ else:
+ if not isinstance(gradient_handler_cfg, list):
+ raise ConfigException(
+ f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
+ )
+
+ # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
+ # to avoid duplicated buffer synchronization
+ if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
+ model.module.sync_buffer = False
+
+ # initialize schedule for engine
+ if is_using_pp():
+ tensor_shape = get_tensor_shape()
+ use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        scatter_gather = gpc.is_initialized(ParallelMode.PARALLEL_1D)
+ if use_interleaved:
+ if isinstance(model, nn.Sequential):
+ model = nn.ModuleList([model])
+ schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+ gpc.config.model.num_chunks,
+ tensor_shape=tensor_shape,
+ scatter_gather_tensors=scatter_gather)
+ else:
+ schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+ tensor_shape=tensor_shape,
+ scatter_gather_tensors=scatter_gather)
+ else:
+ schedule = NonPipelineSchedule()
+
+ if gradient_handler_cfg is None:
+ gradient_handlers = None
+ if verbose and not isinstance(model, DDP):
+ logger.warning(
+ "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
+ "to all-reduce the gradients after a training step.",
+ ranks=[0])
+ else:
+ gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
+
+    # wrap the optimizer if it is not already an OptimizerWrapper
+ if not isinstance(optimizer, (OptimizerWrapper, ShardedOptimizerV2)):
+ optimizer = OptimizerWrapper(optim=optimizer)
+
+ # gradient accumulation
+ grad_accum_size = gpc.config.get('gradient_accumulation', None)
+ if grad_accum_size is not None:
+ optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
+ model=model,
+ optimizer=optimizer,
+ dataloader=train_dataloader,
+ accumulate_size=grad_accum_size,
+ gradient_handlers=gradient_handlers,
+ lr_scheduler=lr_scheduler)
+ engine = Engine(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ gradient_handlers=gradient_handlers,
+ clip_grad_norm=clip_grad_norm,
+ ophook_list=ophooks,
+ schedule=schedule)
+
+ return engine, train_dataloader, test_dataloader, lr_scheduler
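+
+
+# A minimal end-to-end sketch (hypothetical model, data and hyperparameters)
+# combining the launchers above with `initialize`:
+#
+#   launch_from_torch(config='./config.py')
+#   engine, train_dataloader, _, _ = initialize(model, optimizer, criterion,
+#                                               train_dataloader=train_dataloader)
+#   engine.train()
+#   for data, label in train_dataloader:
+#       engine.zero_grad()
+#       loss = engine.criterion(engine(data), label)
+#       engine.backward(loss)
+#       engine.step()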
diff --git a/colossalai/legacy/nn/__init__.py b/colossalai/legacy/nn/__init__.py
index 500162901905..d30ebf8d5406 100644
--- a/colossalai/legacy/nn/__init__.py
+++ b/colossalai/legacy/nn/__init__.py
@@ -1,4 +1,3 @@
-from ._ops import *
from .layer import *
from .loss import *
from .metric import *
diff --git a/colossalai/legacy/nn/_ops/__init__.py b/colossalai/legacy/nn/_ops/__init__.py
index 4991ad9a2217..9a35d02ce5ed 100644
--- a/colossalai/legacy/nn/_ops/__init__.py
+++ b/colossalai/legacy/nn/_ops/__init__.py
@@ -1,9 +1 @@
-from .addmm import colo_addmm
-from .batch_norm import colo_batch_norm
-from .element_wise import *
-from .embedding import colo_embedding
-from .embedding_bag import colo_embedding_bag
-from .layernorm import colo_layernorm
-from .linear import colo_linear
-from .loss import colo_cross_entropy
-from .view import colo_view
+from ._utils import *
diff --git a/colossalai/legacy/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py
index 131c2154771b..a4228fa2116e 100644
--- a/colossalai/legacy/nn/_ops/_utils.py
+++ b/colossalai/legacy/nn/_ops/_utils.py
@@ -3,9 +3,10 @@
import torch
import torch.distributed as dist
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.nn.layer.utils import divide
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup
+from colossalai.tensor import ColoTensor
GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
diff --git a/colossalai/legacy/nn/_ops/addmm.py b/colossalai/legacy/nn/_ops/addmm.py
deleted file mode 100644
index 660b48a71d57..000000000000
--- a/colossalai/legacy/nn/_ops/addmm.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import torch
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input
-
-
-def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
- alpha: Number) -> ColoTensor:
- # mat1:S[1] x mat2:S[0] = Output:P
- # beta * input + alpha * All-Reduce(Output) = res
-
- mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group())
-
- # Output:P
- partial_output = torch.mm(mat1, mat2)
- # Reduce(Output)
- output = reduce_input(partial_output, mat2.get_process_group())
- # input
- assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
- output = beta * input_tensor + alpha * output
- output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group()))
- return output
-
-
-def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
- alpha: Number) -> ColoTensor:
- # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
- compute_spec = mat2.compute_spec
- mat1 = mat1.redistribute(ReplicaSpec())
- mat1 = reduce_grad(mat1, mat1.get_process_group())
-
- output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
- output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
- ComputeSpec(ComputePattern.TP1D))
- output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-
- if compute_spec.output_replicate:
- return output.to_replicate()
- else:
- return output
-
-
-def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
- alpha: Number) -> ColoTensor:
- assert mode in ('row', 'col')
- funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
- return funcs[mode](input_tensor, mat1, mat2, beta, alpha)
-
-
-@colo_op_impl(torch.addmm)
-def colo_addmm(input_tensor: GeneralTensor,
- mat1: ColoTensor,
- mat2: ColoTensor,
- beta: Number = 1,
- alpha: Number = 1,
- **kargs) -> ColoTensor:
- """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
- This method computes a linear.
- """
- # At least one of the tensor should be ColoTensor
- assert isinstance(mat2, ColoTensor)
- input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group())
- mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group())
-
- # Add communication logic before and after linear call.
- ret_tensor = None
- if not mat2.has_compute_spec(): # No Model Parallel Applied
- assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
- assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
- ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor,
- mat1,
- mat2,
- beta=beta,
- alpha=alpha,
- **kargs),
- spec=ColoTensorSpec(mat2.get_process_group()))
- elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
- if mat2.is_shard_1drow() and input_tensor.is_replicate():
- mode = 'row'
- elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()):
- mode = 'col'
- else:
- raise NotImplementedError
- ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
- else:
- raise NotImplementedError
-
- return ret_tensor
diff --git a/colossalai/legacy/nn/_ops/batch_norm.py b/colossalai/legacy/nn/_ops/batch_norm.py
deleted file mode 100644
index 54ecc88f420a..000000000000
--- a/colossalai/legacy/nn/_ops/batch_norm.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from typing import Optional
-
-import torch.nn.functional as F
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-from ._utils import GeneralTensor, convert_to_colo_tensor
-
-
-@colo_op_impl(F.batch_norm)
-def colo_batch_norm(
- input: GeneralTensor,
- running_mean: Optional[GeneralTensor],
- running_var: Optional[GeneralTensor],
- weight: Optional[GeneralTensor] = None,
- bias: Optional[GeneralTensor] = None,
- training: bool = False,
- momentum: float = 0.1,
- eps: float = 1e-5,
-):
- assert isinstance(weight, ColoTensor)
- running_mean = running_mean.detach()
- running_var = running_var.detach()
-
- input = convert_to_colo_tensor(input, weight.get_process_group())
- bias = convert_to_colo_tensor(bias, weight.get_process_group())
- input = input.redistribute(ReplicaSpec())
- bias = bias.redistribute(ReplicaSpec())
-
- output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
- output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group()))
- return output
diff --git a/colossalai/legacy/nn/_ops/element_wise.py b/colossalai/legacy/nn/_ops/element_wise.py
deleted file mode 100644
index 2de51e24a6dd..000000000000
--- a/colossalai/legacy/nn/_ops/element_wise.py
+++ /dev/null
@@ -1,250 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import Tensor
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-from ._utils import GeneralTensor, convert_to_colo_tensor
-
-
-def register_elementwise_op(op):
-
- @colo_op_impl(op)
- def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
- """
- Handles ``__torch_function__`` dispatch for the elementwise op such
- as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
- This method computes on either a normal tensor or a sharded tensor.
- """
- if 'inplace' in kwargs:
- # TODO(jiaruifang) inplace will cause bugs
- input_tensor = input_tensor.clone()
- return op(input_tensor, *args, **kwargs)
- else:
- output = op(input_tensor, *args, **kwargs)
- # return output
- if isinstance(input_tensor, ColoTensor):
- if isinstance(output, str):
- return output
- if not isinstance(output, torch.Tensor):
- raise NotImplementedError
- return ColoTensor.from_torch_tensor(output,
- spec=ColoTensorSpec(input_tensor.get_process_group(),
- dist_attr=input_tensor.dist_spec))
-
-
-# @colo_op_impl(torch.relu_)
-# def elementwise_op(input_tensor):
-# torch.relu_(input_tensor.data)
-# return input_tensor
-
-# @colo_op_impl(Tensor.add_)
-# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
-# input_tensor = input_tensor.data.add_(*args, **kwargs)
-# return input_tensor
-
-# Tensor op
-register_elementwise_op(Tensor.abs)
-register_elementwise_op(Tensor.absolute)
-register_elementwise_op(Tensor.acos)
-register_elementwise_op(Tensor.arccos)
-register_elementwise_op(Tensor.angle)
-register_elementwise_op(Tensor.asin)
-register_elementwise_op(Tensor.arcsin)
-register_elementwise_op(Tensor.atan)
-register_elementwise_op(Tensor.arctan)
-register_elementwise_op(Tensor.all)
-register_elementwise_op(Tensor.any)
-register_elementwise_op(Tensor.bernoulli)
-register_elementwise_op(Tensor.bfloat16)
-register_elementwise_op(Tensor.bitwise_not)
-register_elementwise_op(Tensor.bool)
-register_elementwise_op(Tensor.byte)
-register_elementwise_op(Tensor.ceil)
-register_elementwise_op(Tensor.char)
-register_elementwise_op(Tensor.clamp)
-register_elementwise_op(Tensor.clamp_max)
-register_elementwise_op(Tensor.clamp_min)
-register_elementwise_op(Tensor.clip)
-register_elementwise_op(Tensor.clone)
-register_elementwise_op(Tensor.contiguous)
-register_elementwise_op(Tensor.copysign)
-register_elementwise_op(Tensor.cos)
-register_elementwise_op(Tensor.cosh)
-register_elementwise_op(Tensor.acosh)
-register_elementwise_op(Tensor.arccosh)
-register_elementwise_op(Tensor.cpu)
-register_elementwise_op(Tensor.cuda)
-register_elementwise_op(Tensor.deg2rad)
-register_elementwise_op(Tensor.detach)
-register_elementwise_op(Tensor.digamma)
-register_elementwise_op(Tensor.double)
-register_elementwise_op(Tensor.erf)
-register_elementwise_op(Tensor.erfc)
-register_elementwise_op(Tensor.erfinv)
-register_elementwise_op(Tensor.exp)
-register_elementwise_op(Tensor.expm1)
-register_elementwise_op(Tensor.fix)
-register_elementwise_op(Tensor.trunc)
-register_elementwise_op(Tensor.float)
-register_elementwise_op(Tensor.float_power)
-register_elementwise_op(Tensor.floor)
-register_elementwise_op(Tensor.frac)
-register_elementwise_op(Tensor.half)
-register_elementwise_op(Tensor.hardshrink)
-register_elementwise_op(Tensor.heaviside)
-register_elementwise_op(Tensor.i0)
-register_elementwise_op(Tensor.int)
-register_elementwise_op(Tensor.isfinite)
-register_elementwise_op(Tensor.isinf)
-register_elementwise_op(Tensor.isposinf)
-register_elementwise_op(Tensor.isneginf)
-register_elementwise_op(Tensor.isnan)
-register_elementwise_op(Tensor.lgamma)
-register_elementwise_op(Tensor.log)
-register_elementwise_op(Tensor.log10)
-register_elementwise_op(Tensor.log1p)
-register_elementwise_op(Tensor.log2)
-register_elementwise_op(Tensor.logical_not)
-register_elementwise_op(Tensor.logit)
-register_elementwise_op(Tensor.long)
-register_elementwise_op(Tensor.nan_to_num)
-register_elementwise_op(Tensor.neg)
-register_elementwise_op(Tensor.negative)
-register_elementwise_op(Tensor.positive)
-register_elementwise_op(Tensor.pow)
-register_elementwise_op(Tensor.rad2deg)
-register_elementwise_op(Tensor.reciprocal)
-register_elementwise_op(Tensor.round)
-register_elementwise_op(Tensor.rsqrt)
-register_elementwise_op(Tensor.short)
-register_elementwise_op(Tensor.sigmoid)
-register_elementwise_op(Tensor.sign)
-register_elementwise_op(Tensor.signbit)
-register_elementwise_op(Tensor.sgn)
-register_elementwise_op(Tensor.sin)
-register_elementwise_op(Tensor.sinc)
-register_elementwise_op(Tensor.sinh)
-register_elementwise_op(Tensor.asinh)
-register_elementwise_op(Tensor.arcsinh)
-register_elementwise_op(Tensor.sqrt)
-register_elementwise_op(Tensor.square)
-register_elementwise_op(Tensor.to)
-register_elementwise_op(Tensor.tan)
-register_elementwise_op(Tensor.tanh)
-register_elementwise_op(Tensor.atanh)
-register_elementwise_op(Tensor.arctanh)
-register_elementwise_op(Tensor.type)
-register_elementwise_op(Tensor.type_as)
-
-# torch OP
-register_elementwise_op(torch.abs)
-register_elementwise_op(torch.absolute)
-register_elementwise_op(torch.acos)
-register_elementwise_op(torch.arccos)
-register_elementwise_op(torch.angle)
-register_elementwise_op(torch.asin)
-register_elementwise_op(torch.arcsin)
-register_elementwise_op(torch.atan)
-register_elementwise_op(torch.arctan)
-register_elementwise_op(torch.all)
-register_elementwise_op(torch.any)
-register_elementwise_op(torch.bernoulli)
-register_elementwise_op(torch.bitwise_not)
-register_elementwise_op(torch.ceil)
-register_elementwise_op(torch.clamp)
-register_elementwise_op(torch.clamp_max)
-register_elementwise_op(torch.clamp_min)
-register_elementwise_op(torch.clip)
-register_elementwise_op(torch.clone)
-register_elementwise_op(torch.copysign)
-register_elementwise_op(torch.cos)
-register_elementwise_op(torch.cosh)
-register_elementwise_op(torch.acosh)
-register_elementwise_op(torch.arccosh)
-register_elementwise_op(torch.deg2rad)
-register_elementwise_op(torch.digamma)
-register_elementwise_op(torch.erf)
-register_elementwise_op(torch.erfc)
-register_elementwise_op(torch.erfinv)
-register_elementwise_op(torch.exp)
-register_elementwise_op(torch.expm1)
-register_elementwise_op(torch.fix)
-register_elementwise_op(torch.trunc)
-register_elementwise_op(torch.float_power)
-register_elementwise_op(torch.floor)
-register_elementwise_op(torch.frac)
-register_elementwise_op(torch.hardshrink)
-register_elementwise_op(torch.heaviside)
-register_elementwise_op(torch.i0)
-register_elementwise_op(torch.isfinite)
-register_elementwise_op(torch.isinf)
-register_elementwise_op(torch.isposinf)
-register_elementwise_op(torch.isneginf)
-register_elementwise_op(torch.isnan)
-register_elementwise_op(torch.lgamma)
-register_elementwise_op(torch.log)
-register_elementwise_op(torch.log10)
-register_elementwise_op(torch.log1p)
-register_elementwise_op(torch.log2)
-register_elementwise_op(torch.logical_not)
-register_elementwise_op(torch.logit)
-register_elementwise_op(torch.nan_to_num)
-register_elementwise_op(torch.neg)
-register_elementwise_op(torch.negative)
-register_elementwise_op(torch.positive)
-register_elementwise_op(torch.pow)
-register_elementwise_op(torch.rad2deg)
-register_elementwise_op(torch.reciprocal)
-register_elementwise_op(torch.round)
-register_elementwise_op(torch.rsqrt)
-register_elementwise_op(torch.sigmoid)
-register_elementwise_op(torch.sign)
-register_elementwise_op(torch.signbit)
-register_elementwise_op(torch.sgn)
-register_elementwise_op(torch.sin)
-register_elementwise_op(torch.sinc)
-register_elementwise_op(torch.sinh)
-register_elementwise_op(torch.asinh)
-register_elementwise_op(torch.arcsinh)
-register_elementwise_op(torch.sqrt)
-register_elementwise_op(torch.square)
-register_elementwise_op(torch.tan)
-register_elementwise_op(torch.tanh)
-register_elementwise_op(torch.atanh)
-register_elementwise_op(torch.arctanh)
-register_elementwise_op(torch.zeros_like)
-
-# nn.functional OP
-register_elementwise_op(F.threshold)
-register_elementwise_op(F.relu)
-register_elementwise_op(F.hardtanh)
-register_elementwise_op(F.hardswish)
-register_elementwise_op(F.relu6)
-register_elementwise_op(F.elu)
-register_elementwise_op(F.selu)
-register_elementwise_op(F.celu)
-register_elementwise_op(F.leaky_relu)
-register_elementwise_op(F.prelu)
-register_elementwise_op(F.rrelu)
-register_elementwise_op(F.gelu)
-register_elementwise_op(F.logsigmoid)
-register_elementwise_op(F.hardshrink)
-register_elementwise_op(F.tanhshrink)
-register_elementwise_op(F.softsign)
-register_elementwise_op(F.softplus)
-register_elementwise_op(F.softmin)
-register_elementwise_op(F.softmax)
-register_elementwise_op(F.softshrink)
-register_elementwise_op(F.gumbel_softmax)
-register_elementwise_op(F.log_softmax)
-register_elementwise_op(F.tanh)
-register_elementwise_op(F.sigmoid)
-register_elementwise_op(F.hardsigmoid)
-register_elementwise_op(F.silu)
-register_elementwise_op(F.mish)
-# TODO(ver217): dropout handles seed
-register_elementwise_op(F.dropout)
-register_elementwise_op(F.alpha_dropout)
-register_elementwise_op(F.feature_alpha_dropout)
diff --git a/colossalai/legacy/nn/_ops/embedding.py b/colossalai/legacy/nn/_ops/embedding.py
deleted file mode 100644
index b145d1763380..000000000000
--- a/colossalai/legacy/nn/_ops/embedding.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from typing import Optional
-
-import torch.nn.functional as F
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
-
-
-def colo_embedding_1Dcol(input_tensor: ColoTensor,
- weight: ColoTensor,
- padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2.0,
- scale_grad_by_freq: bool = False,
- sparse: bool = False) -> ColoTensor:
- # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
- # Gather splitted lookup table
- input_tensor = input_tensor.redistribute(ReplicaSpec())
-
- output_parallel = F.embedding(input_tensor,
- weight,
- padding_idx=padding_idx,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- sparse=sparse)
- output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
- ComputeSpec(ComputePattern.TP1D))
- output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-
- compute_spec = weight.compute_spec
-
- if compute_spec.output_replicate:
- return output.to_replicate()
- else:
- return output
-
-
-def colo_embedding_1Drow(input_tensor: ColoTensor,
- weight: ColoTensor,
- padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2.0,
- scale_grad_by_freq: bool = False,
- sparse: bool = False) -> ColoTensor:
- # embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim]
- # get the index of current segment and mask other segments with 0
-
- # get complete input tensor through all-gather
- input_tensor = input_tensor.redistribute(ReplicaSpec())
-
- # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
- tensor_parallel_rank = weight.get_process_group().tp_local_rank()
- num_embeddings_per_partition = weight.size_local(0)
- vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
- vocab_end_index = vocab_start_index + num_embeddings_per_partition
-
- # build the mask.
- input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
- # mask the input.
- # TODO(jzy) masked_input may be an activation managed by ColoTensor.
- masked_input = input_tensor - vocab_start_index
- masked_input[input_mask] = 0
-
- partial_output = F.embedding(masked_input,
- weight,
- padding_idx=padding_idx,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- sparse=sparse)
-
- # Mask the output embedding.
- partial_output[input_mask, :] = 0.
- # Reduce across all the model parallel GPUs.
- output = reduce_input(partial_output, weight.get_process_group())
- output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
- return output
-
-
-def colo_embedding_1d(mode: str,
- input_tensor: ColoTensor,
- weight: ColoTensor,
- padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2.0,
- scale_grad_by_freq: bool = False,
- sparse: bool = False) -> ColoTensor:
- assert mode in ('row', 'col')
- funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
- return funcs[mode](input_tensor,
- weight,
- padding_idx=padding_idx,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- sparse=sparse)
-
-
-@colo_op_impl(F.embedding)
-def colo_embedding(input_tensor: GeneralTensor,
- weight: GeneralTensor,
- padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2.0,
- scale_grad_by_freq: bool = False,
- sparse: bool = False):
- """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
- This method looks up an embedding table.
- """
- assert isinstance(weight, ColoTensor)
- input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
-
- if not weight.has_compute_spec(): # No Model Parallel Applied
- assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
- return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
- weight,
- padding_idx=padding_idx,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- sparse=sparse),
- spec=ColoTensorSpec(weight.get_process_group()))
- elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
- if weight.is_shard_1drow():
- mode = 'row'
- elif weight.is_shard_1dcol():
- mode = 'col'
- else:
- raise NotImplementedError
- return colo_embedding_1d(mode,
- input_tensor,
- weight,
- padding_idx=padding_idx,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- sparse=sparse)
- else:
- raise NotImplementedError
diff --git a/colossalai/legacy/nn/_ops/embedding_bag.py b/colossalai/legacy/nn/_ops/embedding_bag.py
deleted file mode 100644
index 9a656d5871a3..000000000000
--- a/colossalai/legacy/nn/_ops/embedding_bag.py
+++ /dev/null
@@ -1,127 +0,0 @@
-from typing import Optional
-
-import torch.nn.functional as F
-from torch import Tensor
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-from ._utils import GeneralTensor, convert_to_colo_tensor
-
-
-def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
- weight: ColoTensor,
- offsets: Optional[Tensor] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2,
- scale_grad_by_freq: bool = False,
- mode: str = "mean",
- sparse: bool = False,
- per_sample_weights: Optional[Tensor] = None,
- include_last_offset: bool = False,
- padding_idx: Optional[int] = None) -> ColoTensor:
- # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
- # Gather splitted lookup table
- pg = weight.get_process_group()
- input_tensor = input_tensor.redistribute(ReplicaSpec())
-
- output_parallel = F.embedding_bag(input_tensor,
- weight,
- offsets=offsets,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- mode=mode,
- sparse=sparse,
- per_sample_weights=per_sample_weights,
- include_last_offset=include_last_offset,
- padding_idx=padding_idx)
- output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
- output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-
- if weight.compute_spec.output_replicate:
- return output.to_replicate()
- else:
- return output
-
-
-def colo_embedding_bag_1d(tp_mode: str,
- input_tensor: ColoTensor,
- weight: ColoTensor,
- offsets: Optional[Tensor] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2,
- scale_grad_by_freq: bool = False,
- mode: str = "mean",
- sparse: bool = False,
- per_sample_weights: Optional[Tensor] = None,
- include_last_offset: bool = False,
- padding_idx: Optional[int] = None) -> ColoTensor:
- assert tp_mode in ('col',)
- funcs = {'col': colo_embedding_bag_1Dcol}
- return funcs[tp_mode](input_tensor,
- weight,
- offsets=offsets,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- mode=mode,
- sparse=sparse,
- per_sample_weights=per_sample_weights,
- include_last_offset=include_last_offset,
- padding_idx=padding_idx)
-
-
-@colo_op_impl(F.embedding_bag)
-def colo_embedding_bag(input_tensor: GeneralTensor,
- weight: GeneralTensor,
- offsets: Optional[Tensor] = None,
- max_norm: Optional[float] = None,
- norm_type: float = 2,
- scale_grad_by_freq: bool = False,
- mode: str = "mean",
- sparse: bool = False,
- per_sample_weights: Optional[Tensor] = None,
- include_last_offset: bool = False,
- padding_idx: Optional[int] = None):
- """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
- This method looks up an embedding table.
- """
- assert isinstance(weight, ColoTensor)
- input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
-
- # Handle different parallel actions.
-
- if not weight.has_compute_spec(): # No Model Parallel Applied
- assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
- return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
- weight,
- offsets=offsets,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- mode=mode,
- sparse=sparse,
- per_sample_weights=per_sample_weights,
- include_last_offset=include_last_offset,
- padding_idx=padding_idx),
- spec=ColoTensorSpec(weight.get_process_group()))
- elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
- if weight.is_shard_1dcol():
- tp_mode = 'col'
- else:
- raise NotImplementedError
- return colo_embedding_bag_1d(tp_mode,
- input_tensor,
- weight,
- offsets=offsets,
- max_norm=max_norm,
- norm_type=norm_type,
- scale_grad_by_freq=scale_grad_by_freq,
- mode=mode,
- sparse=sparse,
- per_sample_weights=per_sample_weights,
- include_last_offset=include_last_offset,
- padding_idx=padding_idx)
- else:
- raise NotImplementedError
diff --git a/colossalai/legacy/nn/_ops/layernorm.py b/colossalai/legacy/nn/_ops/layernorm.py
deleted file mode 100644
index 9960c5d48096..000000000000
--- a/colossalai/legacy/nn/_ops/layernorm.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from typing import List, Optional
-
-import torch.nn.functional as F
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-from ._utils import GeneralTensor, convert_to_colo_tensor
-
-
-@colo_op_impl(F.layer_norm)
-def colo_layernorm(
- input_tensor: GeneralTensor,
- normalized_shape: List[int],
- weight: Optional[GeneralTensor] = None,
- bias: Optional[GeneralTensor] = None,
- eps: float = 1e-5,
-):
- assert isinstance(weight, ColoTensor)
- input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
- bias = convert_to_colo_tensor(bias, weight.get_process_group())
- input_tensor = input_tensor.redistribute(ReplicaSpec())
-
- output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
- output = ColoTensor.from_torch_tensor(tensor=output,
- spec=ColoTensorSpec(pg=input_tensor.get_process_group(),
- dist_attr=input_tensor.dist_spec))
- return output
diff --git a/colossalai/legacy/nn/_ops/linear.py b/colossalai/legacy/nn/_ops/linear.py
deleted file mode 100644
index 2f2088c61fa8..000000000000
--- a/colossalai/legacy/nn/_ops/linear.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from copy import deepcopy
-from typing import Optional
-
-import torch.nn.functional as F
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
-from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input
-
-
-def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
- # Input:S[1] x Weight:S[0] = Output:P
- # All-Reduce(Output) + bias = res
- # Input:S[1]
- pg = weight.get_process_group()
- input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg)
-
- # Output:P
- partial_output = F.linear(input_tensor, weight)
- # Reduce(Output)
-
- output = reduce_input(partial_output, pg)
- # Bias
- if bias is not None:
- assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
- output = output + bias
-
- output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
- return output
-
-
-def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
- # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
- # All-Gather(Output)
- # Input:B
- compute_spec = weight.compute_spec
- input_tensor = input_tensor.redistribute(ReplicaSpec())
- input_parallel = reduce_grad(input_tensor, weight.get_process_group())
-
- output_parallel = F.linear(input_parallel, weight, bias)
- output = ColoTensor.from_torch_tensor(output_parallel,
- spec=ColoTensorSpec(weight.get_process_group(),
- ShardSpec([-1], [weight.get_tp_world_size()]),
- ComputeSpec(ComputePattern.TP1D)))
- if compute_spec.output_replicate:
- return output.to_replicate()
- else:
- return output
-
-
-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
- assert mode in ('row', 'col')
- funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol}
- return funcs[mode](input_tensor, weight, bias)
-
-
-# @register_colo_graph(input_pos=[1], param_pos=[2, 3])
-def colo_linear_imp(input_tensor: GeneralTensor,
- weight: GeneralTensor,
- bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
- """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
- This method computes a linear.
- """
- assert isinstance(weight, ColoTensor)
- pg = weight.get_process_group()
- assert pg
- input_tensor = convert_to_colo_tensor(input_tensor, pg)
- bias = convert_to_colo_tensor(bias, pg)
- # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
-
- # Add communication logic before and after linear call.
- ret_tensor = None
- if not weight.has_compute_spec(): # No Model Parallel Applied
- assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
- assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
- ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg))
- elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
- if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
- mode = 'row'
- elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()):
- mode = 'col'
- else:
- raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}")
- ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
- else:
- raise NotImplementedError
-
- return ret_tensor
-
-
-def _new_colo_linear_imp(input_tensor: GeneralTensor,
- weight: GeneralTensor,
- bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
- """
- A tentative function to compute the distributed linear layer with the latest sharding spec.
- This function is subject to future change as the current sharding API is not stable.
- """
- # get mesh info
- input_sharding_seq = input_tensor.sharding_spec.sharding_sequence
- weight_sharding_seq = weight.sharding_spec.sharding_sequence
- if bias is not None:
- bias_sharding_seq = bias.sharding_spec.sharding_sequence
- device_mesh = weight.sharding_spec.device_mesh
- pg_axis0 = weight.pg_axis0
- pg_axis1 = weight.pg_axis1
-
- # the last dim of input should have the same spec as the first dim of weight
- # the weight is transposed, so we look at the second dimension
- assert input_sharding_seq[-1] == weight_sharding_seq[1]
-
- if bias is not None:
- assert bias_sharding_seq[0] == weight_sharding_seq[0]
-
- # compute the output sharding sequence
- # as weight is transposed, so we look at the first dimension
- output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1]
- output_shard_seq = deepcopy(output_shard_seq)
-
- # TODO: add reduce grad logic
-
- # handle column and row parallel linear
- # by reusing the implementation above
- out = F.linear(input_tensor, weight)
-
- # run all reduce if necessary
- last_dim_spec = input_sharding_seq[-1]
- if last_dim_spec.is_replica:
- pass
- elif last_dim_spec.shard_list is not None:
- for dim in last_dim_spec.shard_list:
- if dim == 0:
- reduce_input(out, pg_axis0)
- elif dim == 1:
- reduce_input(out, pg_axis1)
- else:
- raise RuntimeError("Found invalid sharding axis {dim}, only 0 or 1 is expected")
- # add bias
- if bias is not None:
- out += bias
-
- # convert shard seq to partition dict
- output_partition_dict = {}
- for index, dim_spec in enumerate(output_shard_seq):
- if not dim_spec.is_replica:
- if index not in output_partition_dict:
- output_partition_dict[index] = []
- output_partition_dict[index].extend(dim_spec.shard_list)
-
- entire_shape = out.shape
- output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict)
- ret_tensor = ColoTensor.from_torch_tensor(out)
- setattr(ret_tensor, 'sharding_spec', output_sharding_spec)
- return ret_tensor
-
-
-def _has_sharding_spec(tensor):
- """
- A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is
- set as the attribute `sharding_spec` on a tensor.
- """
- return hasattr(tensor, 'sharding_spec')
-
-
-@colo_op_impl(F.linear)
-def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
- if _has_sharding_spec(weight):
- return _new_colo_linear_imp(input, weight, bias)
- else:
- return colo_linear_imp(input, weight, bias)
diff --git a/colossalai/legacy/nn/_ops/loss.py b/colossalai/legacy/nn/_ops/loss.py
deleted file mode 100644
index 90efbfa36f2a..000000000000
--- a/colossalai/legacy/nn/_ops/loss.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-
-from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
-from colossalai.tensor import ColoTensor, ColoTensorSpec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-from ._utils import GeneralTensor, convert_to_colo_tensor
-
-
-@colo_op_impl(F.cross_entropy)
-def colo_cross_entropy(input_tensor: GeneralTensor,
- target: GeneralTensor,
- weight: Optional[GeneralTensor] = None,
- size_average: Optional[bool] = None,
- ignore_index: int = -100,
- reduce: Optional[bool] = None,
- reduction: str = "mean",
- label_smoothing: float = 0.0):
- assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor)
- pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else isinstance(target, ColoTensor)
- weight = convert_to_colo_tensor(weight, pg)
- target = convert_to_colo_tensor(target, pg)
- input_tensor = convert_to_colo_tensor(input_tensor, pg)
-
- if input_tensor.is_replicate(): # Input is gathered
- assert target.is_replicate() and (weight is None or weight.is_replicate()), \
- "Target tensor and weight tensor both should be complete"
- output = F.cross_entropy(input_tensor,
- target,
- weight=weight,
- size_average=size_average,
- ignore_index=ignore_index,
- reduce=reduce,
- reduction=reduction,
- label_smoothing=label_smoothing)
- return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
- elif input_tensor.has_compute_spec(): # Single Model Parallel Applied
- if input_tensor.is_shard_1dcol():
- assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in"
- assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function"
- output = VocabParallelCrossEntropyLoss1D()(input_tensor,
- target,
- process_group=input_tensor.process_group.tp_process_group())
- return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
- else:
- raise NotImplementedError
- else:
- raise NotImplementedError
diff --git a/colossalai/legacy/nn/_ops/view.py b/colossalai/legacy/nn/_ops/view.py
deleted file mode 100644
index 3c0bc52337ce..000000000000
--- a/colossalai/legacy/nn/_ops/view.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import operator
-from functools import reduce
-from typing import Optional, Union
-
-import torch
-
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-
-def _all_int(my_iter):
- return all(isinstance(i, int) for i in my_iter)
-
-
-def _get_valid_shape(shape):
- if isinstance(shape, list):
- if _all_int(shape):
- return tuple(shape)
- else:
- raise RuntimeError("expects type(int) but finds an other type")
- elif isinstance(shape, tuple):
- if _all_int(shape):
- return shape
- else:
- return _get_valid_shape(shape[0])
- else:
- raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
-
-
-def _shape_infer(org_sp, tgt_sp):
- cnt = 0
- pos = 0
- for idx, dim in enumerate(tgt_sp):
- if dim < -1:
- raise RuntimeError("invalid shape dimension {}".format(dim))
- elif dim == -1:
- cnt += 1
- pos = idx
-
- if cnt > 1:
- raise RuntimeError("only one dimension can be inferred")
-
- org_prod = reduce(operator.mul, org_sp, 1)
- tgt_prod = reduce(operator.mul, tgt_sp, 1)
-
- if cnt == 0:
- if org_prod != tgt_prod:
- raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
- else:
- return tgt_sp
- elif org_prod % tgt_prod != 0:
- raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
-
- infer_dim = -(org_prod // tgt_prod)
- return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
-
-
-@colo_op_impl(torch.Tensor.view)
-def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
- """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
- Changes the shape of the current tensor.
- """
- assert isinstance(self, ColoTensor)
- # apply original `view` function for replicated colo tensors
- if self.is_replicate():
- return self.view(*shape)
-
- cur_sp = self.size()
- org_sp = self.size_global()
- # parse the passed arguments
- tgt_sp = _get_valid_shape(shape)
- # get the correct shape from inference
- inf_sp = _shape_infer(org_sp, tgt_sp)
-
- if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
- new_shape = (cur_sp[0],) + tgt_sp[1:]
- res = self.view(*new_shape)
- elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
- new_shape = tgt_sp[:-1] + (cur_sp[-1],)
- res = self.view(*new_shape)
- else:
- replicated_t = self.redistribute(dist_spec=ReplicaSpec())
- return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
- spec=ColoTensorSpec(self.get_process_group()))
-
- return ColoTensor.from_torch_tensor(tensor=res,
- spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
-
-
-@colo_op_impl(torch.Tensor.size)
-def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
- size = self.size_global()
- if dim is None:
- return size
- else:
- return size[dim]
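The `-1` handling in `_shape_infer` above is easy to verify standalone: the product over the target shape keeps the sign of the single `-1`, so floor division recovers the missing dimension. For a (4, 8) tensor viewed as (2, -1):

    import operator
    from functools import reduce

    org_sp, tgt_sp = (4, 8), (2, -1)
    org_prod = reduce(operator.mul, org_sp, 1)   # 32 elements globally
    tgt_prod = reduce(operator.mul, tgt_sp, 1)   # -2: the -1 flips the sign
    assert org_prod % tgt_prod == 0
    inferred = -(org_prod // tgt_prod)           # -(32 // -2) == 16
    assert tgt_sp[:1] + (inferred,) == (2, 16)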
diff --git a/colossalai/legacy/nn/layer/base_layer.py b/colossalai/legacy/nn/layer/base_layer.py
index 4a06bdcb7629..01fd9b3e8943 100644
--- a/colossalai/legacy/nn/layer/base_layer.py
+++ b/colossalai/legacy/nn/layer/base_layer.py
@@ -5,8 +5,8 @@
import torch.nn as nn
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
class ParallelLayer(nn.Module):
diff --git a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py
index 0c049cb3f408..7b0481a3f53c 100644
--- a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py
@@ -1,6 +1,6 @@
import torch.nn as nn
-from colossalai.context import ParallelMode, seed
+from colossalai.legacy.context import ParallelMode, seed
from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
index 300baf9c12ba..db9dfa3667b4 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -1,7 +1,7 @@
import torch
import torch.distributed as dist
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
try:
import fused_mix_prec_layer_norm_cuda
diff --git a/colossalai/legacy/nn/layer/parallel_1d/_utils.py b/colossalai/legacy/nn/layer/parallel_1d/_utils.py
index fddf4e73db51..15b41e305cba 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/_utils.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_utils.py
@@ -4,8 +4,8 @@
import torch
import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from ..utils import divide
diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py
index c0a169c1596f..db7986b8e8e5 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py
@@ -10,18 +10,18 @@
from torch import Tensor
from torch.nn.parameter import Parameter
-from colossalai.context import ParallelMode, seed
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.legacy.communication import broadcast
+from colossalai.legacy.context import ParallelMode, seed
+from colossalai.legacy.context.parallel_context import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import LAYERS
-from colossalai.nn import init as init
-from colossalai.utils.checkpointing import (
+from colossalai.legacy.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
+from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from ..base_layer import ParallelLayer
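The checkpointing helpers that move to the legacy namespace here (partition/gather of tensor-parallel state dicts) amount to chunking each parameter along its parallel dimension per rank. A toy illustration of the partition direction, with a hypothetical helper name:

    import torch

    def partition_param(param, tp_rank, tp_size, dim=0):
        # each tensor-parallel rank keeps one contiguous shard of the full weight
        return torch.chunk(param, tp_size, dim=dim)[tp_rank].clone()

    full = torch.arange(8.).reshape(4, 2)
    shard = partition_param(full, tp_rank=1, tp_size=2)   # rows 2..3
    assert shard.shape == (2, 2)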
diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
index fa9b49bcf53f..43e14d4a47a5 100644
--- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
@@ -5,10 +5,10 @@
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce, reduce_scatter
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.utils import get_current_device
@@ -31,9 +31,9 @@ def matmul_2d(
out_shape (:class:`torch.size`): shape of output tensor.
row_rank (int, optional): the rank of row, defaults to None.
col_rank (int, optional): the rank of column, defaults to None.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`, optional):
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`, optional):
row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`, optional):
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`, optional):
column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL.
Returns:
@@ -146,8 +146,8 @@ def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int,
out_shape (:class:`torch.size`): shape of output tensor.
row_rank (int, optional): the rank of row, defaults to None.
col_rank (int, optional): the rank of column, defaults to None.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
pipeline_parallel_size (int): pipeline parallel size.
@@ -172,8 +172,8 @@ class Matmul_AB_2D(torch.autograd.Function):
out_shape (:class:`torch.size`): shape of output tensor.
row_rank (int, optional): the rank of row, defaults to None.
col_rank (int, optional): the rank of column, defaults to None.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
pipeline_parallel_size (int): pipeline parallel size.
@@ -299,8 +299,8 @@ class Matmul_ABT_2D(torch.autograd.Function):
out_shape (:class:`torch.size`): shape of output tensor.
row_rank (int, optional): the rank of row, defaults to None.
col_rank (int, optional): the rank of column, defaults to None.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
@@ -433,8 +433,8 @@ class Matmul_ATB_2D(torch.autograd.Function):
out_shape (:class:`torch.size`): shape of output tensor.
row_rank (int, optional): the rank of row, defaults to None.
col_rank (int, optional): the rank of column, defaults to None.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
pipeline_parallel_size (int): pipeline parallel size.
@@ -620,8 +620,8 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro
output_size_per_partition (int): size of output per partition.
row_rank (int, optional): the rank of row, defaults to None.
col_rank (int, optional): the rank of column, defaults to None.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
skip_bias_add (bool):
If set to ``True``, it will skip the bias add for the linear layer, which is reserved for kernel fusion.
data_parallel_rank (int): data parallel rank.
@@ -685,8 +685,8 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r
E_x (:class:`torch.tensor`): mean.
Var_x (:class:`torch.tensor`): variance.
hidden_size (int): hidden size.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -719,7 +719,7 @@ def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
Args:
tensor (:class:`torch.tensor`): Input tensor.
dim (int): Dimension to gather.
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode used by the tensor.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -767,7 +767,7 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
Args:
input_ (:class:`torch.tensor`): Input tensor.
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode used by the tensor.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -795,7 +795,7 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
Args:
tensor (:class:`torch.tensor`): Input tensor.
dim (int): Dimension to reduce.
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode used by the tensor.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
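Matmul_AB_2D and its siblings implement SUMMA on a q x q grid: at step k, the k-th column block of A is broadcast along each row and the k-th row block of B along each column, and every rank accumulates one partial product. A serial emulation of the blocking on a 2x2 grid (no communication, just the algorithm):

    import torch

    q = 2
    A, B = torch.randn(4, 4), torch.randn(4, 4)
    Ab = [list(r.chunk(q, dim=1)) for r in A.chunk(q, dim=0)]   # q x q blocks
    Bb = [list(r.chunk(q, dim=1)) for r in B.chunk(q, dim=0)]
    Cb = [[torch.zeros(2, 2) for _ in range(q)] for _ in range(q)]

    for k in range(q):               # SUMMA outer loop
        for i in range(q):           # rank (i, j) receives A[i][k] (row broadcast)
            for j in range(q):       # and B[k][j] (column broadcast), then accumulates
                Cb[i][j] += Ab[i][k] @ Bb[k][j]

    C = torch.cat([torch.cat(row, dim=1) for row in Cb], dim=0)
    assert torch.allclose(C, A @ B, atol=1e-5)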
diff --git a/colossalai/legacy/nn/layer/parallel_2d/_utils.py b/colossalai/legacy/nn/layer/parallel_2d/_utils.py
index 012fec41c802..87ba1bf69691 100644
--- a/colossalai/legacy/nn/layer/parallel_2d/_utils.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/_utils.py
@@ -1,6 +1,6 @@
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
def get_summa_dim_from_env() -> int:
diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py
index b458d15c78e7..893bc74b57d9 100644
--- a/colossalai/legacy/nn/layer/parallel_2d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py
@@ -8,13 +8,16 @@
from torch import Tensor
from torch.nn import Parameter
-from colossalai.context import ParallelMode, seed
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
from colossalai.legacy.communication import broadcast
+from colossalai.legacy.context import ParallelMode, seed
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import LAYERS
+from colossalai.legacy.utils.checkpointing import (
+ gather_tensor_parallel_state_dict,
+ partition_tensor_parallel_state_dict,
+)
from colossalai.nn import init as init
-from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
from colossalai.utils.cuda import get_current_device
from ..base_layer import ParallelLayer
diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
index 55defa4a328d..1226162ae399 100644
--- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
@@ -5,9 +5,9 @@
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device
@@ -112,8 +112,8 @@ def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: T
out_shape (:class:`torch.size`): shape of output tensor.
row_rank (int): the rank of row.
col_rank (int): the rank of column.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
pipeline_parallel_size (int): pipeline parallel size.
@@ -139,8 +139,8 @@ class Matmul_AB_2p5D(torch.autograd.Function):
row_rank (int): the rank of row.
col_rank (int): the rank of column.
dep_rank (int): the rank of depth.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
pipeline_parallel_size (int): pipeline parallel size.
@@ -264,8 +264,8 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
row_rank (int): the rank of row.
col_rank (int): the rank of column.
dep_rank (int): the rank of depth.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
pipeline_parallel_size (int): pipeline parallel size.
@@ -394,8 +394,8 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
row_rank (int): the rank of row.
col_rank (int): the rank of column.
dep_rank (int): the rank of depth.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
data_parallel_rank (int): data parallel rank.
pipeline_parallel_rank (int): pipeline parallel rank.
pipeline_parallel_size (int): pipeline parallel size.
@@ -606,7 +606,7 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
row_rank (int): the rank of row.
col_rank (int): the rank of column.
dep_rank (int): the rank of depth.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
skip_bias_add (bool): If set to ``True``, it will skip the bias add for the linear layer,
which is reserved for kernel fusion.
data_parallel_rank (int): data parallel rank.
@@ -631,7 +631,7 @@ class _Layernorm2p5D(torch.autograd.Function):
E_x (:class:`torch.tensor`): mean.
Var_x (:class:`torch.tensor`): variance.
hidden_size (int): hidden size.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -682,7 +682,7 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
E_x (:class:`torch.tensor`): mean.
Var_x (:class:`torch.tensor`): variance.
hidden_size (int): hidden size.
- row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode.
+ row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -715,7 +715,7 @@ def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: Parallel
Args:
inputs (:class:`torch.tensor`): input tensor.
dim (int): dimension of all-gather.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -730,7 +730,7 @@ class SplitFirst(torch.autograd.Function):
Args:
inputs (:class:`torch.tensor`): input tensor.
tesseract_dim (int): dimension of TESSERACT for 2.5D parallelism.
- col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode.
+ col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -798,7 +798,7 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
Args:
input_ (:class:`torch.tensor`): Input tensor.
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode used by the tensor.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -826,7 +826,7 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel
Args:
input_ (:class:`torch.tensor`): Input tensor.
dim (int): Dimension to reduce.
- parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode used by the tensor.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py
index 1478b25de618..69a350a977ac 100644
--- a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py
@@ -1,6 +1,6 @@
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
def get_tesseract_dim_dep_from_env():
diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
index 04acc2bb0f4c..b4aa9f16ddf0 100644
--- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
@@ -8,17 +8,17 @@
from torch import Tensor
from torch.nn import Parameter
-from colossalai.context import ParallelMode, seed
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
from colossalai.legacy.communication import broadcast
+from colossalai.legacy.context import ParallelMode, seed
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import LAYERS
-from colossalai.nn import init as init
-from colossalai.utils.checkpointing import (
+from colossalai.legacy.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
+from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from ..base_layer import ParallelLayer
diff --git a/colossalai/legacy/nn/layer/parallel_3d/_operation.py b/colossalai/legacy/nn/layer/parallel_3d/_operation.py
index ca0b0e62783a..c6374efb7124 100755
--- a/colossalai/legacy/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/_operation.py
@@ -7,10 +7,10 @@
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.legacy.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from colossalai.legacy.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from ._utils import get_parallel_mode_from_env, push_async_grad
@@ -73,9 +73,9 @@ def linear_3d(
Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
- input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
- weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
- output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
+ input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode.
+ weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode.
+ output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -166,9 +166,9 @@ def classifier_3d(
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
- input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
- weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
- output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
+ input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode.
+ weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode.
+ output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -260,9 +260,9 @@ def vocab_parallel_classifier_3d(
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
- input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
- weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
- output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
+ input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode.
+ weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode.
+ output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -378,8 +378,8 @@ def layernorm_3d(
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability.
- output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
- input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.
+ output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode.
+ input_x_weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input x weight parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -404,7 +404,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
Args:
tensor (:class:`torch.tensor`): Input tensor.
dim (int): Specified dimension in which to split.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): Parallel mode.
Returns:
:class:`torch.tensor`: The split tensor.
@@ -434,8 +434,8 @@ def split_batch_3d(input_: Tensor,
Args:
input_ (:class:`torch.tensor`): Input tensor.
dim (int): Specified dimension in which to split.
- input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): input parallel mode.
- weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): weight parallel mode.
+ input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): input parallel mode.
+ weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): weight parallel mode.
Returns:
:class:`torch.tensor`: The split tensor.
@@ -471,7 +471,7 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
Args:
tensor (:class:`torch.tensor`): Input tensor.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -501,7 +501,7 @@ def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
Args:
tensor (:class:`torch.tensor`): Input tensor.
dim (int): Dimension to gather.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -530,7 +530,7 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
Args:
tensor (:class:`torch.tensor`): Input tensor.
dim (int): Dimension to scatter.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode.
+ parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -578,8 +578,8 @@ def reduce_by_batch_3d(tensor: Tensor,
r"""All-reduce the input from the model parallel region.
Args:
- input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
- weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
+ input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode.
+ weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode.
reduce_mean (bool, optional): If set to ``True``, it will divide the output by
(input parallel size * weight parallel size), defaults to False.
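With reduce_mean=True, reduce_by_batch_3d is an all-reduce(SUM) followed by division by the participating group size. The semantics, emulated serially for four hypothetical ranks:

    import torch

    per_rank = [torch.tensor(v) for v in (1.0, 2.0, 3.0, 6.0)]
    reduced = torch.stack(per_rank).sum(0)   # what dist.all_reduce(SUM) would yield
    mean = reduced / len(per_rank)           # reduce_mean=True divides by group size
    assert mean.item() == 3.0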
diff --git a/colossalai/legacy/nn/layer/parallel_3d/_utils.py b/colossalai/legacy/nn/layer/parallel_3d/_utils.py
index 364191a79f88..cb300c2a9684 100644
--- a/colossalai/legacy/nn/layer/parallel_3d/_utils.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/_utils.py
@@ -4,9 +4,15 @@
import torch
from torch import Tensor
-from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.constants import (
+ INPUT_GROUP_3D,
+ INPUT_X_WEIGHT_3D,
+ OUTPUT_GROUP_3D,
+ OUTPUT_X_WEIGHT_3D,
+ WEIGHT_GROUP_3D,
+)
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
def get_depth_from_env() -> int:
diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py
index b815a842ca52..d6aaa427b9e6 100644
--- a/colossalai/legacy/nn/layer/parallel_3d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py
@@ -8,19 +8,25 @@
from torch import Tensor
from torch.nn import Parameter
-from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
-from colossalai.context import ParallelMode, seed
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
from colossalai.legacy.communication import all_reduce, broadcast
+from colossalai.legacy.constants import (
+ INPUT_GROUP_3D,
+ INPUT_X_WEIGHT_3D,
+ OUTPUT_GROUP_3D,
+ OUTPUT_X_WEIGHT_3D,
+ WEIGHT_GROUP_3D,
+)
+from colossalai.legacy.context import ParallelMode, seed
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.nn.layer.base_layer import ParallelLayer
from colossalai.legacy.registry import LAYERS
-from colossalai.nn import init as init
-from colossalai.utils.checkpointing import (
+from colossalai.legacy.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
+from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
index fcf2962017a3..ea1863f0b474 100644
--- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
@@ -5,9 +5,9 @@
from torch import distributed as dist
from torch.cuda.amp import custom_bwd, custom_fwd
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.legacy.communication import ring_forward
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
from colossalai.utils import get_current_device
diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py
index e44e61c2fb7d..033c1be962ae 100644
--- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py
@@ -9,11 +9,11 @@
from torch.nn import Parameter
import colossalai
-from colossalai.context import seed
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.context import seed
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
from colossalai.legacy.registry import LAYERS
diff --git a/colossalai/legacy/nn/layer/utils/common.py b/colossalai/legacy/nn/layer/utils/common.py
index d8f3ad2a7eca..3148a0bed570 100644
--- a/colossalai/legacy/nn/layer/utils/common.py
+++ b/colossalai/legacy/nn/layer/utils/common.py
@@ -8,9 +8,9 @@
import torch
from torch import Tensor, nn
-from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
-from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.utils import checkpoint
+from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
+from colossalai.legacy.global_variables import tensor_parallel_env as env
+from colossalai.legacy.utils import checkpoint
class CheckpointModule(nn.Module):
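CheckpointModule wraps its forward in the `checkpoint` utility imported above, trading memory for recompute: activations inside the block are not stored and are recomputed during backward. The stock PyTorch equivalent, as a minimal sketch:

    import torch
    from torch.utils.checkpoint import checkpoint

    block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    x = torch.randn(2, 16, requires_grad=True)
    y = checkpoint(block, x)   # forward pass stores no intermediate activations
    y.sum().backward()         # backward re-runs the block to recover them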
diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py
index 0e11fc4d0dab..71ca1d421de6 100644
--- a/colossalai/legacy/nn/layer/vanilla/layers.py
+++ b/colossalai/legacy/nn/layer/vanilla/layers.py
@@ -7,7 +7,7 @@
from torch import nn as nn
from torch.nn.parameter import Parameter
-from colossalai.context import seed
+from colossalai.legacy.context import seed
from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
from colossalai.utils.cuda import get_current_device
@@ -64,7 +64,7 @@ class WrappedDropout(nn.Module):
Args:
p (float, optional): probability of an element to be zeroed, defaults to 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False.
- mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
@@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module):
Args:
p (float, optional): probability of dropping path, defaults to 0.0.
- mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+ mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
diff --git a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
index 68fea8622c5c..ec19d1b707d8 100644
--- a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
+++ b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
@@ -3,8 +3,8 @@
import torch.distributed as dist
import torch.nn as nn
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
class PipelineSharedModuleWrapper:
diff --git a/colossalai/legacy/nn/loss/__init__.py b/colossalai/legacy/nn/loss/__init__.py
index 1bd8872d9c3a..abb7ec3ef824 100644
--- a/colossalai/legacy/nn/loss/__init__.py
+++ b/colossalai/legacy/nn/loss/__init__.py
@@ -2,7 +2,7 @@
from torch.nn.modules.loss import *
from torch.nn.modules.loss import _Loss
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode
from .loss_1d import VocabParallelCrossEntropyLoss1D
diff --git a/colossalai/legacy/nn/loss/loss_1d.py b/colossalai/legacy/nn/loss/loss_1d.py
index 8c9483fccaec..2582e8b359d5 100644
--- a/colossalai/legacy/nn/loss/loss_1d.py
+++ b/colossalai/legacy/nn/loss/loss_1d.py
@@ -3,8 +3,8 @@
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.modules.loss import _Loss
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import LOSSES
diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py
index 6191602b71ee..7ab58415608a 100644
--- a/colossalai/legacy/nn/loss/loss_2d.py
+++ b/colossalai/legacy/nn/loss/loss_2d.py
@@ -4,8 +4,8 @@
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.legacy.registry import LOSSES
diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py
index 2746b201152c..8a5d04a8c788 100644
--- a/colossalai/legacy/nn/loss/loss_2p5d.py
+++ b/colossalai/legacy/nn/loss/loss_2p5d.py
@@ -4,8 +4,8 @@
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.legacy.registry import LOSSES
diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py
index 2aeb1bd9825d..a576d84f71cd 100644
--- a/colossalai/legacy/nn/loss/loss_3d.py
+++ b/colossalai/legacy/nn/loss/loss_3d.py
@@ -4,8 +4,8 @@
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
-from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.core import global_context as gpc
+from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.legacy.registry import LOSSES
diff --git a/colossalai/legacy/nn/metric/accuracy_3d.py b/colossalai/legacy/nn/metric/accuracy_3d.py
index 1aaac73ecabd..675f5c2b5120 100644
--- a/colossalai/legacy/nn/metric/accuracy_3d.py
+++ b/colossalai/legacy/nn/metric/accuracy_3d.py
@@ -1,7 +1,7 @@
import torch
from torch import nn
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.legacy.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
diff --git a/colossalai/legacy/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py
index f839d6b28444..2b2ad36a74f4 100644
--- a/colossalai/legacy/nn/parallel/data_parallel.py
+++ b/colossalai/legacy/nn/parallel/data_parallel.py
@@ -5,7 +5,7 @@
import torch
import torch.distributed as dist
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.legacy.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import is_ddp_ignored
from .reducer import Reducer
@@ -34,8 +34,8 @@ class ColoDDP(torch.nn.Module):
"""Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.
Example:
- >>> from colossalai.core import global_context as gpc
- >>> from colossalai.context import ParallelMode
+ >>> from colossalai.legacy.core import global_context as gpc
+ >>> from colossalai.legacy.context import ParallelMode
>>> model = torch.nn.Linear(20, 1)
>>> pg = ProcessGroup(tp_degree = world_size//2)
>>> model = ColoDDP(model, pg)
diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
index 79d7672b26bc..522fb4f4497f 100644
--- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
@@ -4,7 +4,8 @@
import torch.nn.functional as F
from colossalai.legacy.nn._ops._utils import dual_all_to_all
-from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
+from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
+from colossalai.tensor import ColoParameter, ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
index 116d836b7139..a1feda2bdb0e 100644
--- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
@@ -6,7 +6,7 @@
import torch.nn.functional as F
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
-from colossalai.tensor import ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
from .cache_mgr import EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
index 0014c784fba1..8017ee72b0b4 100644
--- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
@@ -7,7 +7,7 @@
from torch.profiler import record_function
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
-from colossalai.tensor import ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
from .cache_mgr import EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
diff --git a/colossalai/legacy/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py
index a0a3eb40cf08..69d92afaaa94 100644
--- a/colossalai/legacy/nn/parallel/layers/colo_module.py
+++ b/colossalai/legacy/nn/parallel/layers/colo_module.py
@@ -1,7 +1,7 @@
from typing import Dict, List
-from colossalai.tensor import ComputePattern
-from colossalai.tensor.distspec import _DistSpec
+from colossalai.legacy.tensor import ComputePattern
+from colossalai.legacy.tensor.distspec import _DistSpec
class ColoModule(object):
diff --git a/colossalai/legacy/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py
index 3e4e7ffd8de7..4796699fc57f 100644
--- a/colossalai/legacy/nn/parallel/layers/embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/embedding.py
@@ -1,4 +1,4 @@
-from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
+from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from .colo_module import ColoModule
diff --git a/colossalai/legacy/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py
index e391cf808933..51a8d4c976a6 100644
--- a/colossalai/legacy/nn/parallel/layers/linear.py
+++ b/colossalai/legacy/nn/parallel/layers/linear.py
@@ -1,4 +1,4 @@
-from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
+from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from .colo_module import ColoModule
diff --git a/colossalai/legacy/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py
index 191266fa70fd..09326d2d6f9a 100644
--- a/colossalai/legacy/nn/parallel/layers/module_utils.py
+++ b/colossalai/legacy/nn/parallel/layers/module_utils.py
@@ -2,7 +2,8 @@
import torch
-from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec
+from colossalai.legacy.tensor import ComputeSpec, ProcessGroup, distspec
+from colossalai.tensor import ColoParameter
from . import ColoModule
diff --git a/colossalai/legacy/pipeline/__init__.py b/colossalai/legacy/pipeline/__init__.py
new file mode 100644
index 000000000000..f36f54ac9307
--- /dev/null
+++ b/colossalai/legacy/pipeline/__init__.py
@@ -0,0 +1,4 @@
+from .layer_spec import LayerSpec
+from .pipelinable import PipelinableContext, PipelinableModel
+
+__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
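The package init re-exports the pipeline builder API from its new legacy home; only the import path changes for callers. Roughly, assuming the pre-move semantics (MyModel is a placeholder for any nn.Module, not a real class):

    from colossalai.legacy.pipeline import PipelinableContext

    pipelinable = PipelinableContext()
    with pipelinable:                 # layer construction is recorded as LayerSpecs
        model = MyModel()             # placeholder: your own model class
    pipelinable.to_layer_list()       # flatten into an ordered layer list for partitioning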
diff --git a/colossalai/pipeline/layer_spec.py b/colossalai/legacy/pipeline/layer_spec.py
similarity index 97%
rename from colossalai/pipeline/layer_spec.py
rename to colossalai/legacy/pipeline/layer_spec.py
index 7e9169efff78..3960debd7f72 100644
--- a/colossalai/pipeline/layer_spec.py
+++ b/colossalai/legacy/pipeline/layer_spec.py
@@ -1,9 +1,11 @@
import torch
+
from colossalai.utils.model.utils import call_to_str
+
class LayerSpec:
"""
-
+
"""
def __init__(self, typename, *module_args, **module_kwargs):
@@ -52,4 +54,4 @@ def count_params(self):
return self._param_count
def reset_param_count(self):
- self._param_count = 0
\ No newline at end of file
+ self._param_count = 0
diff --git a/colossalai/legacy/pipeline/middleware/__init__.py b/colossalai/legacy/pipeline/middleware/__init__.py
new file mode 100644
index 000000000000..481741bfee31
--- /dev/null
+++ b/colossalai/legacy/pipeline/middleware/__init__.py
@@ -0,0 +1,3 @@
+from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
+
+__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py
similarity index 62%
rename from colossalai/pipeline/middleware/adaptor/__init__.py
rename to colossalai/legacy/pipeline/middleware/adaptor/__init__.py
index 949700a2c49d..0b0d36d2ffe5 100644
--- a/colossalai/pipeline/middleware/adaptor/__init__.py
+++ b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py
@@ -1,3 +1,3 @@
from .fx import get_topology as get_fx_topology
-__all__ = ['get_fx_topology']
\ No newline at end of file
+__all__ = ['get_fx_topology']
diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/legacy/pipeline/middleware/adaptor/fx.py
similarity index 92%
rename from colossalai/pipeline/middleware/adaptor/fx.py
rename to colossalai/legacy/pipeline/middleware/adaptor/fx.py
index 8437c5194762..8cc40f120f15 100644
--- a/colossalai/pipeline/middleware/adaptor/fx.py
+++ b/colossalai/legacy/pipeline/middleware/adaptor/fx.py
@@ -1,6 +1,8 @@
-from torch.fx.graph_module import GraphModule
-from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
import torch
+from torch.fx.graph_module import GraphModule
+
+from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
+
def partition_name_to_id(partition_name, is_input=False, is_output=False):
if is_input:
@@ -12,6 +14,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id
+
# There are two kinds of definitions in fx.graph:
# 1. non direct_use & non direct_def, which means the output is used by the next partition through a temporary mid value.
# e.g. submod1 = call_module(...)
@@ -20,6 +23,8 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
# 2. direct_use & direct_def, which means the output is used by the next partition directly.
# e.g. submod1 = call_module(...)
# submod2 = call_module(submod1, ...)
+
+
def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
@@ -45,9 +50,10 @@ def find_input_in_partition(node, partitions, input_partitions=None):
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
return p_input_val
-
+
return p_input_val
-
+
+
def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
@@ -70,7 +76,7 @@ def find_output_in_partition(node, partitions, output_partitions=None):
if arg == user:
p_output_val.add(partition_id=partition_id, offset=i)
break
-
+
# user is output
if output_partitions is not None:
output_node = output_partitions[0]
@@ -84,10 +90,11 @@ def find_output_in_partition(node, partitions, output_partitions=None):
break
return p_output_val
+
def get_topology(gm: GraphModule):
topo = Topo()
topo_output_partition = Partition()
-
+
input_partitions = []
partitions = []
output_partitions = []
@@ -109,7 +116,7 @@ def get_topology(gm: GraphModule):
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition_id(partition_id=0)
-
+
for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
# set input for submodule
@@ -131,15 +138,16 @@ def get_topology(gm: GraphModule):
for user in partition.users:
cur_node = user
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
- topo_mid_partition.add_output_val(p_output_val)
- topo.set_partitions(partition_id=i+2, partition=topo_mid_partition)
-
+ topo_mid_partition.add_output_val(p_output_val)
+ topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition)
+
# set input for output_partition
for partition in output_partitions:
topo_output_partition = Partition()
- torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
- find_input_in_partition(n, partitions, input_partitions)))
+ torch.fx.graph.map_arg(
+ partition.args[0],
+ lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)
- return topo
\ No newline at end of file
+ return topo
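The `getitem` special case in `find_input_in_partition` mirrors how torch.fx names nodes: indexing a tuple-valued proxy inserts call_function getitem nodes between the producer and its consumers. A small trace makes that visible:

    import torch
    import torch.fx as fx

    class M(torch.nn.Module):
        def forward(self, x):
            parts = torch.split(x, 2)    # tuple-valued node
            return parts[0] * parts[1]   # indexing inserts getitem nodes

    gm = fx.symbolic_trace(M())
    print([(n.op, n.name) for n in gm.graph.nodes])
    # [('placeholder', 'x'), ('call_function', 'split'),
    #  ('call_function', 'getitem'), ('call_function', 'getitem_1'), ...]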
diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/legacy/pipeline/middleware/topo.py
similarity index 95%
rename from colossalai/pipeline/middleware/topo.py
rename to colossalai/legacy/pipeline/middleware/topo.py
index e798e2ed9cab..3c21cce6dc0e 100644
--- a/colossalai/pipeline/middleware/topo.py
+++ b/colossalai/legacy/pipeline/middleware/topo.py
@@ -1,49 +1,54 @@
-from typing import Dict, List
from dataclasses import dataclass
+from typing import Dict, List
# This file includes the data structures used by the pipeline middleware.
+
@dataclass
class ValPosition:
partition_id: int
offset: int
-
+
def __str__(self) -> str:
res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
return res
-
+
def __repr__(self) -> str:
return self.__str__()
+
class PartitionInputVal(object):
+
def __init__(self, partition_id, offset) -> None:
# records which partition_id and offset this input value comes from
val_pos = ValPosition(partition_id, offset)
self._from_partition_and_offset: ValPosition = val_pos
-
+
def get(self):
return self._from_partition_and_offset
-
+
def __str__(self) -> str:
res = ''
res += f'<-({self._from_partition_and_offset})'
return res
-
+
def __repr__(self) -> str:
return self.__str__()
-
+
+
class PartitionOutputVal(object):
+
def __init__(self) -> None:
# records the partition_ids and offsets this output value is sent to
self._to_partition_and_offset: List[ValPosition] = []
-
+
def add(self, partition_id, offset):
val_pos = ValPosition(partition_id, offset)
self._to_partition_and_offset.append(val_pos)
-
+
def get(self):
return self._to_partition_and_offset
-
+
def __str__(self) -> str:
res = ''
res += '->('
@@ -51,27 +56,29 @@ def __str__(self) -> str:
res += f'{val_pos},'
res += ')'
return res
-
+
def __repr__(self) -> str:
return self.__str__()
+
class Partition(object):
+
def __init__(self) -> None:
self._input_vals: List[PartitionInputVal] = []
self._output_vals: List[PartitionOutputVal] = []
-
+
def add_input_val(self, input_val: PartitionInputVal):
self._input_vals.append(input_val)
-
+
def add_output_val(self, output_val: PartitionOutputVal):
self._output_vals.append(output_val)
-
+
def get_input_vals(self):
return self._input_vals
-
+
def get_output_vals(self):
return self._output_vals
-
+
# get the output offsets sent to dst_partition_id
def get_output_offsets(self, dst_partition_id):
res = []
@@ -80,9 +87,9 @@ def get_output_offsets(self, dst_partition_id):
for val_pos in outputs:
if val_pos.partition_id == dst_partition_id:
res.append(offset)
-
+
return res
-
+
# get all input dst partition_ids
def get_input_partition_ids(self):
res = []
@@ -91,7 +98,7 @@ def get_input_partition_ids(self):
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
-
+
# get all output dst partition_ids
def get_output_partition_ids(self):
res = []
@@ -101,24 +108,25 @@ def get_output_partition_ids(self):
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
-
+
def __str__(self) -> str:
res = ''
res += f' input:\n'
res += f' length:{len(self._input_vals)}\n'
for i, input_val in enumerate(self._input_vals):
res += f' offset={i}:{input_val}\n'
-
+
res += f' output:\n'
res += f' length:{len(self._output_vals)}\n'
for i, output_val in enumerate(self._output_vals):
res += f' offset={i}:{output_val}\n'
-
+
return res
-
+
def __repr__(self) -> str:
return self.__str__()
+
# This class is a middleware between partition splitter
# and Pipeline Scheduler. It records the graph info about
# partition input/output and provides it to the scheduler.
@@ -132,42 +140,43 @@ def __repr__(self) -> str:
# _input_partition_id: the key represents input_partition
# _output_partition_id: the key represents output_partition
class Topo(object):
+
def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
self._partitions: Dict[int, Partition] = {}
self._input_partition_id = input_partition_id
self._output_partition_id = output_partition_id
-
+
def set_input_partition_id(self, partition_id: int):
self._input_partition_id = partition_id
-
+
def set_output_partition_id(self, partition_id: int):
self._output_partition_id = partition_id
-
+
def get_input_partition_id(self):
return self._input_partition_id
-
+
def get_output_partition_id(self):
return self._output_partition_id
-
+
def set_partitions(self, partition_id: int, partition: Partition):
self._partitions[partition_id] = partition
-
+
def get_mid_partitions(self):
- res = {} #{partition_id: Partition}
+ res = {} #{partition_id: Partition}
for partition_id, partition in self._partitions.items():
if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
continue
res[partition_id] = partition
return res
-
+
def get_mid_partition_ids(self):
return list(self.get_mid_partitions().keys())
-
+
def get_input_partition(self):
if self._input_partition_id is not None:
return self._partitions[self._input_partition_id]
return None
-
+
def get_output_partition(self):
if self._output_partition_id is not None:
return self._partitions[self._output_partition_id]
@@ -175,7 +184,7 @@ def get_output_partition(self):
def get_partition_by_id(self, partition_id):
return self._partitions[partition_id]
-
+
def __str__(self) -> str:
res = ''
if len(self._partitions) == 0:
@@ -186,21 +195,20 @@ def __str__(self) -> str:
res += '{\n'
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
res += '}\n'
-
+
mid_parts = self.get_mid_partitions()
for i, (partition_id, part) in enumerate(mid_parts.items()):
res += '{\n'
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
res += '}\n'
-
+
output_part = self.get_output_partition()
if output_part is not None:
res += '{\n'
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
res += '}\n'
-
+
return res
-
+
def __repr__(self) -> str:
return self.__str__()
-
\ No newline at end of file
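
Taken together, the classes in this relocated `topo.py` describe the dataflow between pipeline partitions. The following minimal sketch uses only the constructors and setters shown above (the import path matches the one used by `_pipeline_base.py` later in this patch); the two-partition wiring itself is illustrative:

    from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo

    topo = Topo()

    # Partition 0 sends its only output to partition 1 at offset 0.
    p0 = Partition()
    out_val = PartitionOutputVal()
    out_val.add(partition_id=1, offset=0)
    p0.add_output_val(out_val)

    # Partition 1 receives that value from partition 0, offset 0.
    p1 = Partition()
    p1.add_input_val(PartitionInputVal(partition_id=0, offset=0))

    topo.set_partitions(partition_id=0, partition=p0)
    topo.set_partitions(partition_id=1, partition=p1)
    topo.set_input_partition_id(0)
    topo.set_output_partition_id(1)
    print(topo)    # human-readable dump via __str__
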
diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/legacy/pipeline/pipelinable.py
similarity index 93%
rename from colossalai/pipeline/pipelinable.py
rename to colossalai/legacy/pipeline/pipelinable.py
index ba8b1591da9d..e74cad0ad1b0 100644
--- a/colossalai/pipeline/pipelinable.py
+++ b/colossalai/legacy/pipeline/pipelinable.py
@@ -1,20 +1,16 @@
-import inspect
-
import torch
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .layer_spec import LayerSpec
from .utils import (
- build_kwargs_for_function,
build_kwargs_for_module,
call_module,
customized_partition,
- exec_func_with_kwargs,
exec_funcs_with_kwargs,
partition_balanced,
partition_uniform,
@@ -135,8 +131,10 @@ def to_layer_list(self, exec_seq=None):
children_name = []
for child in self._root_children:
layer_spec = self._layer_spec_dict[id(child)]
- if layer_spec.typename in (torch.nn.modules.container.ModuleList,
- torch.nn.modules.container.Sequential):
+ if layer_spec.typename in (
+ torch.nn.modules.container.ModuleList,
+ torch.nn.modules.container.Sequential,
+ ):
for child_in_container in layer_spec.children:
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
for name, module in self._model.named_modules():
@@ -155,9 +153,11 @@ def to_layer_list(self, exec_seq=None):
named_modules = dict(self._model.named_modules())
for index, element in enumerate(exec_seq):
if isinstance(element, str):
- if element == 'SPLIT_NODE':
+ if element == "SPLIT_NODE":
continue
- assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'
+            assert (
+                element in named_modules
+            ), f"Found invalid module name {element}, please check that the module name is spelled correctly."
# get the layer spec based on the module ID
module = named_modules[element]
@@ -198,11 +198,12 @@ def partition(self, num_chunks, pipeline_size, rank):
param_counts.append(layer_spec.count_params())
parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
elif self._policy == "customized":
- assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.'
+            assert (self._exec_seq
+                    is not None), "An explicit exec_seq must be defined by the user in customized policy mode."
self.customized_parts = customized_partition(self._exec_seq)
assert len(self.customized_parts) == gpc.get_world_size(
ParallelMode.PIPELINE
- ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}'
+ ), f"World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}"
parts = self.customized_parts[rank]
else:
raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].")
@@ -241,7 +242,6 @@ def __init__(self, module_list, front_func_dict, behind_func_dict):
def forward(self, *input_tensor, **kwargs):
for module in self._module_list:
-
if id(module) in self._front_func_dict:
input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
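
For the `customized` policy above, `exec_seq` is a list of module-name strings in execution order, with the literal `'SPLIT_NODE'` marking stage boundaries; the assertions in this hunk enforce that every name exists in `named_modules` and that the number of resulting partitions equals the pipeline world size. A hedged sketch, where the module names are hypothetical and `pipelinable` stands for an already-built pipelinable context:

    # 'SPLIT_NODE' separates the two pipeline stages; the module names are made up.
    exec_seq = ['embedding', 'encoder', 'SPLIT_NODE', 'decoder', 'head']
    pipelinable.to_layer_list(exec_seq)
    # rank should be this process's pipeline rank; 0 is used for illustration.
    stage_module = pipelinable.partition(num_chunks=1, pipeline_size=2, rank=0)
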
diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/legacy/pipeline/pipeline_process_group.py
similarity index 98%
rename from colossalai/pipeline/pipeline_process_group.py
rename to colossalai/legacy/pipeline/pipeline_process_group.py
index c61d97ebabfa..1168158defaf 100644
--- a/colossalai/pipeline/pipeline_process_group.py
+++ b/colossalai/legacy/pipeline/pipeline_process_group.py
@@ -1,11 +1,11 @@
-from typing import List, Dict, Tuple
import os
import threading
+from typing import Dict, List, Tuple
-from torch.distributed import rpc
import torch.distributed as dist
+from torch.distributed import rpc
-from colossalai.tensor import ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
class PipelineProcessGroup:
diff --git a/colossalai/legacy/pipeline/rpc/__init__.py b/colossalai/legacy/pipeline/rpc/__init__.py
new file mode 100644
index 000000000000..15b65a4138a8
--- /dev/null
+++ b/colossalai/legacy/pipeline/rpc/__init__.py
@@ -0,0 +1,4 @@
+from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine
+from .utils import pytree_map
+
+__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/legacy/pipeline/rpc/_pipeline_base.py
similarity index 99%
rename from colossalai/pipeline/rpc/_pipeline_base.py
rename to colossalai/legacy/pipeline/rpc/_pipeline_base.py
index 9e549df58214..88ddb9e98eb2 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/legacy/pipeline/rpc/_pipeline_base.py
@@ -12,9 +12,9 @@
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
-from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (
+from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.legacy.pipeline.pipeline_process_group import ppg
+from colossalai.legacy.pipeline.rpc.utils import (
get_batch_lengths,
pyobj_map,
pytree_filter,
diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py
similarity index 97%
rename from colossalai/pipeline/rpc/_pipeline_schedule.py
rename to colossalai/legacy/pipeline/rpc/_pipeline_schedule.py
index 6eda8f3b34b7..f53a4835edf2 100644
--- a/colossalai/pipeline/rpc/_pipeline_schedule.py
+++ b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py
@@ -6,8 +6,8 @@
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
+from colossalai.legacy.pipeline.pipeline_process_group import ppg
+from colossalai.legacy.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
# Implementation of different pipeline schedules
# Worker defines the worker for each stage
@@ -78,7 +78,7 @@ def _get_work_item_key(self) -> UniqueKey:
# 1. forward times reach actual_stage_num, this is the end of continuous forward
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
if not is_last_stage and \
- target_key.phase == Phase.FORWARD:
+ target_key.phase == Phase.FORWARD:
if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
# Why require num_microbatches > 2? Because there is no steady 1F1B stage when num_microbatches <= 2
outstanding_min = actual_stage_num - pp_rank - 1
@@ -144,7 +144,7 @@ def _get_work_item_key(self) -> UniqueKey:
forward_block_num = self.forward_times // forward_block_size
if self.forward_times >= real_microbatch_num or \
- ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
+ ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else: # others
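
The `outstanding_min = actual_stage_num - pp_rank - 1` bound in the first hunk is the usual 1F1B warm-up depth: earlier stages keep more forward microbatches in flight before issuing their first backward. A toy illustration of the bound:

    # Illustration only: warm-up depth per rank for a 4-stage pipeline.
    actual_stage_num = 4
    for pp_rank in range(actual_stage_num):
        outstanding_min = actual_stage_num - pp_rank - 1
        print(f'pp_rank={pp_rank}: keeps at least {outstanding_min} forward pass(es) outstanding')
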
diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/legacy/pipeline/rpc/utils.py
similarity index 98%
rename from colossalai/pipeline/rpc/utils.py
rename to colossalai/legacy/pipeline/rpc/utils.py
index 06e6d976d771..d1033fbde920 100644
--- a/colossalai/pipeline/rpc/utils.py
+++ b/colossalai/legacy/pipeline/rpc/utils.py
@@ -10,7 +10,7 @@
from torch.futures import Future
from colossalai.initialize import launch
-from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.legacy.pipeline.pipeline_process_group import ppg
def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
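
Only the signature of `pyobj_map` is visible in this hunk. Assuming it applies `fn` to every element of `obj` whose type matches `process_types`, recursing through nested containers, a usage sketch looks like:

    import torch

    from colossalai.legacy.pipeline.rpc.utils import pyobj_map

    batch = {'input_ids': torch.ones(2, 8), 'label': 'keep-me'}
    # Move only the tensors to GPU, leaving non-tensor fields untouched.
    batch_on_gpu = pyobj_map(batch, fn=lambda t: t.cuda(), process_types=torch.Tensor)
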
diff --git a/colossalai/pipeline/utils.py b/colossalai/legacy/pipeline/utils.py
similarity index 100%
rename from colossalai/pipeline/utils.py
rename to colossalai/legacy/pipeline/utils.py
diff --git a/colossalai/legacy/tensor/__init__.py b/colossalai/legacy/tensor/__init__.py
new file mode 100644
index 000000000000..d3278bf1e420
--- /dev/null
+++ b/colossalai/legacy/tensor/__init__.py
@@ -0,0 +1,17 @@
+from . import distspec
+from .compute_spec import ComputePattern, ComputeSpec
+from .dist_spec_mgr import DistSpecManager
+from .distspec import ReplicaSpec, ShardSpec
+from .process_group import ProcessGroup
+from .tensor_spec import ColoTensorSpec
+
+__all__ = [
+ 'ComputePattern',
+ 'ComputeSpec',
+ 'distspec',
+ 'DistSpecManager',
+ 'ProcessGroup',
+ 'ColoTensorSpec',
+ 'ShardSpec',
+ 'ReplicaSpec',
+]
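
The re-exported spec types compose as follows. The constructor arguments are assumed from typical ColossalAI usage and should be verified against the moved modules; a distributed environment must already be initialized:

    from colossalai.legacy.tensor import (
        ColoTensorSpec,
        ComputePattern,
        ComputeSpec,
        ProcessGroup,
        ReplicaSpec,
        ShardSpec,
    )

    pg = ProcessGroup(tp_degree=2)                    # assumed: 2-way tensor parallelism
    shard = ShardSpec(dims=[0], num_partitions=[2])   # shard dim 0 into 2 pieces
    spec = ColoTensorSpec(pg, dist_attr=shard, compute_attr=ComputeSpec(ComputePattern.TP1D))
    replica = ReplicaSpec()                           # fully replicated placement
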
diff --git a/colossalai/tensor/compute_spec.py b/colossalai/legacy/tensor/compute_spec.py
similarity index 100%
rename from colossalai/tensor/compute_spec.py
rename to colossalai/legacy/tensor/compute_spec.py
diff --git a/colossalai/tensor/const.py b/colossalai/legacy/tensor/const.py
similarity index 100%
rename from colossalai/tensor/const.py
rename to colossalai/legacy/tensor/const.py
diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/legacy/tensor/dist_spec_mgr.py
similarity index 97%
rename from colossalai/tensor/dist_spec_mgr.py
rename to colossalai/legacy/tensor/dist_spec_mgr.py
index 4740a316b7f5..d97308b04bef 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/legacy/tensor/dist_spec_mgr.py
@@ -4,12 +4,12 @@
import torch.distributed as dist
from numpy import prod
-from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
-from colossalai.tensor.process_group import ProcessGroup
+from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.legacy.tensor.process_group import ProcessGroup
# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
-# colossalai.tensor shall not import any submodule from colossal.nn
+# colossalai.legacy.tensor shall not import any submodule from colossal.nn
def divide(numerator, denominator):
"""Only allow exact division.
diff --git a/colossalai/tensor/distspec.py b/colossalai/legacy/tensor/distspec.py
similarity index 100%
rename from colossalai/tensor/distspec.py
rename to colossalai/legacy/tensor/distspec.py
diff --git a/colossalai/tensor/op_wrapper.py b/colossalai/legacy/tensor/op_wrapper.py
similarity index 97%
rename from colossalai/tensor/op_wrapper.py
rename to colossalai/legacy/tensor/op_wrapper.py
index 1c00066f7465..63ebaa264279 100644
--- a/colossalai/tensor/op_wrapper.py
+++ b/colossalai/legacy/tensor/op_wrapper.py
@@ -1,8 +1,5 @@
-from typing import (
- Callable,
- Dict,
-)
import functools
+from typing import Callable, Dict
# Custom sharded ops
_COLOSSAL_OPS: Dict[str, Callable] = {}
diff --git a/colossalai/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py
similarity index 100%
rename from colossalai/tensor/process_group.py
rename to colossalai/legacy/tensor/process_group.py
diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py
similarity index 79%
rename from colossalai/tensor/tensor_spec.py
rename to colossalai/legacy/tensor/tensor_spec.py
index 580df9f8f310..aa792e507639 100644
--- a/colossalai/tensor/tensor_spec.py
+++ b/colossalai/legacy/tensor/tensor_spec.py
@@ -1,8 +1,8 @@
from dataclasses import dataclass
from typing import Optional
-from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
-from colossalai.tensor.process_group import ProcessGroup
+from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.legacy.tensor.process_group import ProcessGroup
from .compute_spec import ComputeSpec
diff --git a/colossalai/legacy/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py
index 1847e56222a1..1cb99fcc90ed 100644
--- a/colossalai/legacy/trainer/_trainer.py
+++ b/colossalai/legacy/trainer/_trainer.py
@@ -6,8 +6,9 @@
from colossalai.legacy.engine import Engine
from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.legacy.utils import is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
+from colossalai.utils import MultiTimer
class Trainer:
diff --git a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
index 6b150d29139f..cda10030bf65 100644
--- a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
+++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
@@ -4,8 +4,8 @@
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.legacy.utils.checkpointing import save_checkpoint
from colossalai.logging import get_dist_logger
-from colossalai.utils.checkpointing import save_checkpoint
from ._lr_scheduler_hook import LRSchedulerHook
diff --git a/colossalai/legacy/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py
index 7d9ad19aa9e9..b1a398ce7f71 100644
--- a/colossalai/legacy/trainer/hooks/_log_hook.py
+++ b/colossalai/legacy/trainer/hooks/_log_hook.py
@@ -5,12 +5,13 @@
import os.path as osp
from typing import List
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
+from colossalai.legacy.utils import is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
+from colossalai.utils import MultiTimer
from ._base_hook import BaseHook
from ._commons_ import _format_number
@@ -112,8 +113,8 @@ class TensorboardHook(BaseHook):
Args:
log_dir (str): Directory of log.
ranks (list): Ranks of processors.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer,
- defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL.
+ parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer,
+ defaults to colossalai.legacy.context.parallel_mode.ParallelMode.GLOBAL.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py
index f1bd19387cb5..899e4d08a5c9 100644
--- a/colossalai/legacy/trainer/hooks/_metric_hook.py
+++ b/colossalai/legacy/trainer/hooks/_metric_hook.py
@@ -7,11 +7,12 @@
import torch
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.legacy.communication import all_reduce
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import HOOKS
-from colossalai.utils import get_current_device, is_no_pp_or_last_stage
+from colossalai.legacy.utils import is_no_pp_or_last_stage
+from colossalai.utils import get_current_device
from ._base_hook import BaseHook
from ._commons_ import _format_number
diff --git a/colossalai/legacy/utils/__init__.py b/colossalai/legacy/utils/__init__.py
new file mode 100644
index 000000000000..ae358f8bebcb
--- /dev/null
+++ b/colossalai/legacy/utils/__init__.py
@@ -0,0 +1,53 @@
+from .checkpointing import load_checkpoint, save_checkpoint
+from .common import (
+ clip_grad_norm_fp32,
+ copy_tensor_parallel_attributes,
+ count_zeros_fp32,
+ is_dp_rank_0,
+ is_model_parallel_parameter,
+ is_no_pp_or_last_stage,
+ is_tp_rank_0,
+ is_using_ddp,
+ is_using_pp,
+ is_using_sequence,
+ param_is_not_tensor_parallel_duplicate,
+ print_rank_0,
+ switch_virtual_pipeline_parallel_rank,
+ sync_model_param,
+)
+from .data_sampler import DataParallelSampler, get_dataloader
+from .memory import (
+ colo_device_memory_capacity,
+ colo_device_memory_used,
+ colo_get_cpu_memory_capacity,
+ colo_set_cpu_memory_capacity,
+ colo_set_process_memory_fraction,
+ report_memory_usage,
+)
+
+__all__ = [
+ 'DataParallelSampler',
+ 'get_dataloader',
+ 'save_checkpoint',
+ 'load_checkpoint',
+ 'colo_device_memory_capacity',
+ 'colo_device_memory_used',
+ 'colo_get_cpu_memory_capacity',
+ 'colo_set_cpu_memory_capacity',
+ 'colo_set_process_memory_fraction',
+ 'report_memory_usage',
+ 'clip_grad_norm_fp32',
+ 'copy_tensor_parallel_attributes',
+ 'count_zeros_fp32',
+ 'is_dp_rank_0',
+ 'is_model_parallel_parameter',
+ 'is_no_pp_or_last_stage',
+ 'is_tp_rank_0',
+ 'is_using_ddp',
+ 'is_using_pp',
+ 'is_using_sequence',
+ 'param_is_not_tensor_parallel_duplicate',
+ 'print_rank_0',
+ 'switch_virtual_pipeline_parallel_rank',
+ 'sync_model_param',
+]
diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py
similarity index 95%
rename from colossalai/utils/activation_checkpoint.py
rename to colossalai/legacy/utils/activation_checkpoint.py
index fa9ed827a8a7..add690f28cc0 100644
--- a/colossalai/utils/activation_checkpoint.py
+++ b/colossalai/legacy/utils/activation_checkpoint.py
@@ -1,13 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import weakref
+
import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable
-from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
-from .cuda import get_current_device
-
-import weakref
+from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
+from colossalai.utils import get_current_device
def copy_to_device(obj, device):
@@ -143,7 +143,7 @@ def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
Args:
function: The forward pass function. It should know how to handle the input tuples.
- activation_offload: The variable to check whether we should offload activation to cpu
+        activation_offload: Whether to offload activations to the CPU
args (list): Tuple containing the parameters of the function
use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
might be more flexibility for the user to define their checkpoint function
@@ -227,12 +227,12 @@ def inner_unpack(packed):
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
with torch.enable_grad(), \
- torch.cuda.amp.autocast(), \
- torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+ torch.cuda.amp.autocast(), \
+ torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
_unused = function(*args)
else:
with torch.enable_grad(), \
- torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+ torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
_unused = function(*args)
if x not in storage:
diff --git a/colossalai/legacy/utils/checkpoint/__init__.py b/colossalai/legacy/utils/checkpoint/__init__.py
new file mode 100644
index 000000000000..558a956b31ac
--- /dev/null
+++ b/colossalai/legacy/utils/checkpoint/__init__.py
@@ -0,0 +1,3 @@
+from .module_checkpoint import load_checkpoint, save_checkpoint
+
+__all__ = ['save_checkpoint', 'load_checkpoint']
diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/legacy/utils/checkpoint/module_checkpoint.py
similarity index 90%
rename from colossalai/utils/checkpoint/module_checkpoint.py
rename to colossalai/legacy/utils/checkpoint/module_checkpoint.py
index d390da864cd3..9bd2907abf9d 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/legacy/utils/checkpoint/module_checkpoint.py
@@ -1,25 +1,28 @@
+from typing import Dict, Optional
+
import torch
import torch.distributed as dist
+
+from colossalai.interface import OptimizerWrapper
from colossalai.tensor import ColoTensor
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
-from typing import Optional, Dict
+
+from .utils import gather_tensor, scatter_tensor
def save_checkpoint(path: str,
epoch: int,
model: torch.nn.Module,
- optimizer: Optional[ColossalaiOptimizer] = None,
+ optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
- """save_checkpoint
+ """save_checkpoint
Save a model whose parameters are `ColoTensor`s.
Args:
path (str): directory to save the checkpoint files.
epoch (int): the number of epoch
model (torch.nn.Module): a torch module initialized by ColoInitContext
- optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
+ optimizer (OptimizerWrapper, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""
rank = dist.get_rank()
@@ -74,17 +77,17 @@ def save_checkpoint(path: str,
def load_checkpoint(path: str,
epoch: int,
model: torch.nn.Module,
- optimizer: Optional[ColossalaiOptimizer] = None,
+ optimizer: Optional[OptimizerWrapper] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
torch_load_kwargs: Optional[Dict] = None,
load_state_dict_kwargs: Optional[Dict] = None):
- """load_checkpoint
+ """load_checkpoint
load a model, whose parameters are `ColoTensor`s.
Args:
path (str): directory to save the checkpoint files.
epoch (int): the number of epoch
model (torch.nn.Module): a torch module initialized by ColoInitContext
- optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
+ optimizer (OptimizerWrapper, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
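
Matching the updated signatures, a typical round trip looks like the sketch below; `model` and `optimizer` are assumed to be an already-initialized module and `OptimizerWrapper`, and the path and epoch are illustrative:

    from colossalai.legacy.utils.checkpoint import load_checkpoint, save_checkpoint

    save_checkpoint('./ckpt', epoch=10, model=model, optimizer=optimizer)
    load_checkpoint('./ckpt', epoch=10, model=model, optimizer=optimizer,
                    torch_load_kwargs={'map_location': 'cpu'})
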
diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/legacy/utils/checkpoint/utils.py
similarity index 91%
rename from colossalai/utils/checkpoint/utils.py
rename to colossalai/legacy/utils/checkpoint/utils.py
index 682cd0903d5b..c830d4811463 100644
--- a/colossalai/utils/checkpoint/utils.py
+++ b/colossalai/legacy/utils/checkpoint/utils.py
@@ -1,63 +1,65 @@
-import torch
-import torch.distributed as dist
-from colossalai.tensor import ColoTensor, ColoTensorSpec
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-
-
-def robust_broadcast(tensor):
- with torch.no_grad():
- is_cpu_ten = tensor.device.type == 'cpu'
- if is_cpu_ten:
- b_data = tensor.cuda()
- else:
- b_data = tensor
-
- dist.broadcast(b_data, 0)
-
- if is_cpu_ten:
- tensor.copy_(b_data)
-
-
-def gather_tensor(colo_tensor: ColoTensor) -> None:
- """Make colo_tensor replicated when the rank is 0
- """
- if not colo_tensor.is_replicate():
- pg = colo_tensor.get_process_group()
- # for the group which contains rank 0
- if pg.dp_local_rank() == 0:
- old_dist_spec = colo_tensor.dist_spec
- colo_tensor.to_replicate_()
- if dist.get_rank() != 0:
- colo_tensor.set_dist_spec(old_dist_spec)
-
- # synchronize all processes for unexpected problems
- dist.barrier()
-
- if dist.get_rank() == 0:
- setattr(colo_tensor, 'save_ready', True) # set saving signature
-
-
-def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
- """Reversal operation of `gather_tensor`.
- """
- if dist_spec.placement == DistPlacementPattern.REPLICATE:
- robust_broadcast(colo_tensor.data)
- else:
- global_size = colo_tensor.size_global()
-
- if dist.get_rank() == 0:
- entire_data = colo_tensor.data
- else:
- entire_data = torch.empty(global_size, device=colo_tensor.device)
- robust_broadcast(entire_data)
-
- if dist.get_rank() == 0:
- colo_tensor.set_dist_spec(dist_spec)
- else:
- rep_tensor = ColoTensor(
- entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
- rep_tensor.set_dist_spec(dist_spec)
- with torch.no_grad():
- colo_tensor.data.copy_(rep_tensor.data)
- # synchronize all processes for unexpected problems
- dist.barrier()
+import torch
+import torch.distributed as dist
+
+from colossalai.legacy.tensor import ColoTensorSpec
+from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor import ColoTensor
+
+
+def robust_broadcast(tensor):
+ with torch.no_grad():
+ is_cpu_ten = tensor.device.type == 'cpu'
+ if is_cpu_ten:
+ b_data = tensor.cuda()
+ else:
+ b_data = tensor
+
+ dist.broadcast(b_data, 0)
+
+ if is_cpu_ten:
+ tensor.copy_(b_data)
+
+
+def gather_tensor(colo_tensor: ColoTensor) -> None:
+    """Gather colo_tensor so that it is fully replicated on rank 0.
+ """
+ if not colo_tensor.is_replicate():
+ pg = colo_tensor.get_process_group()
+ # for the group which contains rank 0
+ if pg.dp_local_rank() == 0:
+ old_dist_spec = colo_tensor.dist_spec
+ colo_tensor.to_replicate_()
+ if dist.get_rank() != 0:
+ colo_tensor.set_dist_spec(old_dist_spec)
+
+        # synchronize all processes to guard against unexpected problems
+ dist.barrier()
+
+ if dist.get_rank() == 0:
+ setattr(colo_tensor, 'save_ready', True) # set saving signature
+
+
+def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
+    """Reverse operation of `gather_tensor`.
+ """
+ if dist_spec.placement == DistPlacementPattern.REPLICATE:
+ robust_broadcast(colo_tensor.data)
+ else:
+ global_size = colo_tensor.size_global()
+
+ if dist.get_rank() == 0:
+ entire_data = colo_tensor.data
+ else:
+ entire_data = torch.empty(global_size, device=colo_tensor.device)
+ robust_broadcast(entire_data)
+
+ if dist.get_rank() == 0:
+ colo_tensor.set_dist_spec(dist_spec)
+ else:
+ rep_tensor = ColoTensor(
+ entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
+ rep_tensor.set_dist_spec(dist_spec)
+ with torch.no_grad():
+ colo_tensor.data.copy_(rep_tensor.data)
+    # synchronize all processes to guard against unexpected problems
+ dist.barrier()
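
`gather_tensor` and `scatter_tensor` form a round trip around rank-0 serialization. A hedged sketch, where `colo_tensor` is an assumed, already-distributed `ColoTensor`:

    import torch
    import torch.distributed as dist

    from colossalai.legacy.utils.checkpoint.utils import gather_tensor, scatter_tensor

    old_spec = colo_tensor.dist_spec            # remember the original layout
    gather_tensor(colo_tensor)                  # replicate onto rank 0; sets `save_ready`
    if dist.get_rank() == 0:
        torch.save(colo_tensor.data, 'tensor.pt')    # safe to serialize here
    scatter_tensor(colo_tensor, old_spec)       # restore the distributed layout
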
diff --git a/colossalai/utils/checkpointing.py b/colossalai/legacy/utils/checkpointing.py
similarity index 98%
rename from colossalai/utils/checkpointing.py
rename to colossalai/legacy/utils/checkpointing.py
index d1c6b6370ede..b7b29cc984d6 100644
--- a/colossalai/utils/checkpointing.py
+++ b/colossalai/legacy/utils/checkpointing.py
@@ -3,9 +3,11 @@
import torch
import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.constants import IS_TENSOR_PARALLEL
+
+from colossalai.legacy.constants import IS_TENSOR_PARALLEL
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py
new file mode 100644
index 000000000000..35095161c2f2
--- /dev/null
+++ b/colossalai/legacy/utils/common.py
@@ -0,0 +1,434 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch import inf
+from torch.nn.parameter import Parameter
+
+from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter
+from colossalai.utils.multi_tensor_apply import multi_tensor_applier
+
+try:
+ from colossalai._C import fused_optim
+except:
+ fused_optim = None
+
+
+def print_rank_0(msg: str, logger=None):
+    """Print a message and optionally log it. This is executed only on the rank-0 GPU.
+
+ Args:
+ msg (str): A string message to output.
+ logger (:class:`colossalai.logging.DistributedLogger`, optional):
+ The logger to record the message, defaults to None.
+ """
+ if gpc.get_global_rank() == 0:
+ if logger is None:
+ print(msg, flush=True)
+ else:
+ logger.info(msg)
+
+
+def sync_model_param(model, parallel_mode):
+    r"""Make sure model parameters are consistent across the given parallel mode.
+
+ Args:
+        model (:class:`torch.nn.Module`): A PyTorch model whose parameters are checked for consistency.
+ parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel mode to be checked.
+
+ Note:
+        The parallel_mode should be a member of ``ParallelMode``. More details about ``ParallelMode`` can be found
+        in the definition of ``ParallelMode``.
+ """
+ if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
+ for param in model.parameters():
+ ranks = gpc.get_ranks_in_group(parallel_mode)
+ dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
+
+
+def is_dp_rank_0():
+ return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)
+
+
+def is_tp_rank_0():
+ return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)
+
+
+def is_no_pp_or_last_stage():
+ return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
+
+
+def is_using_ddp():
+ return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1
+
+
+def is_using_pp():
+ return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
+
+
+def is_using_sequence():
+ return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
+
+
+class model_branch_context(object):
+
+ def __enter__(self):
+ self.env_status = env.save()
+
+ def __exit__(self, *exc_info):
+ env.load(**self.env_status)
+
+
+def is_model_parallel_parameter(p):
+ return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
+
+
+def _calc_l2_norm(grads):
+    # load the fused optimizer kernel lazily, on first use rather than at import time
+ global fused_optim
+
+ if fused_optim is None:
+ from colossalai.kernel.op_builder import FusedOptimBuilder
+ fused_optim = FusedOptimBuilder().load()
+
+ norm = 0.0
+ if len(grads) > 0:
+ dummy_overflow_buf = torch.cuda.IntTensor([0])
+ norm, _ = multi_tensor_applier(
+ fused_optim.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads],
+ False # no per-parameter norm
+ )
+ return norm
+
+
+def _calc_lp(grads, norm_type):
+ norm = 0.0
+ for grad in grads:
+ grad_norm = torch.norm(grad, norm_type)
+ norm += grad_norm**norm_type
+ return norm
+
+
+def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+ if torch.is_tensor(norm) and norm.device.type != 'cuda':
+ norm = norm.to(torch.cuda.current_device())
+ return norm
+
+
+def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
+ if isinstance(norm, float):
+ norm = torch.Tensor([norm])
+ if move_to_cuda:
+ norm = norm.to(torch.cuda.current_device())
+ return norm
+
+
+# ======== Gradient Clipping =========
+
+
+def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
+ if len(params) == 0:
+ return 0.0
+ grads = [p.grad for p in params]
+ use_cuda_kernel = grads[0].device.type == 'cuda'
+ if norm_type == inf:
+ local_lp = max([g.abs().max() for g in grads])
+ elif norm_type == 2.0 and use_cuda_kernel:
+ local_lp = _calc_l2_norm(grads)**norm_type
+ else:
+ local_lp = _calc_lp(grads, norm_type)
+ if isinstance(local_lp, torch.Tensor):
+ return local_lp.item()
+ return local_lp
+
+
+def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
+ if len(params) == 0:
+ return 0.0
+ buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
+ for p in params:
+ if p.is_replicate():
+ buckets[None].append(p)
+ else:
+ buckets[p.get_process_group().tp_process_group()].append(p)
+ total_lp = 0.0
+ for group, bucket in buckets.items():
+ local_lp = _compute_local_lp(bucket, norm_type)
+ if group is not None:
+ local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
+ if norm_type == inf:
+ dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
+ else:
+ dist.all_reduce(local_lp_tensor, group=group)
+ local_lp = local_lp_tensor.item()
+ if norm_type == inf:
+ total_lp = max(total_lp, local_lp)
+ else:
+ total_lp += local_lp
+ return total_lp
+
+
+def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
+ if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+ total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
+ if norm_type == inf:
+ dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
+ else:
+ dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
+ total_lp = total_lp_tensor.item()
+ return total_lp
+
+
+def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grad_dtype = None
+ cpu_grad_params: List[ColoParameter] = []
+ cuda_grad_params: List[ColoParameter] = []
+ for p in parameters:
+ if p.grad is None:
+ continue
+ assert isinstance(p, ColoParameter)
+ if grad_dtype is None:
+ grad_dtype = p.grad.dtype
+ assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
+ if p.grad.device.type == 'cuda':
+ cuda_grad_params.append(p)
+ else:
+ cpu_grad_params.append(p)
+ norm_type = float(norm_type)
+ cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
+ cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
+ if norm_type == inf:
+ total_lp = max(cpu_lp, cuda_lp)
+ else:
+ total_lp = cpu_lp + cuda_lp
+ return _compute_pp_grad_lp(total_lp, norm_type)
+
+
+def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
+ norm_type = float(norm_type)
+ total_norm = _compute_grad_lp(parameters, norm_type)
+ if norm_type != inf:
+ total_norm = total_norm**(1 / norm_type)
+ return total_norm
+
+
+def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
+ clip_coef = max_norm / (total_norm + 1e-6)
+ if clip_coef < 1.0:
+ cuda_grads: List[torch.Tensor] = []
+ cpu_grads: List[torch.Tensor] = []
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ for p in parameters:
+ if p.grad is None:
+ continue
+ if p.grad.device.type == 'cuda':
+ cuda_grads.append(p.grad.detach())
+ else:
+ cpu_grads.append(p.grad.detach())
+ if len(cuda_grads) > 0:
+ dummy_overflow_buf = torch.cuda.IntTensor([0])
+ multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
+ clip_coef)
+ for g in cpu_grads:
+ g.mul_(clip_coef)
+
+
+def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
+ total_norm = compute_grad_norm(parameters, norm_type)
+ _clip_grad_norm(parameters, max_norm, total_norm)
+ return total_norm
+
+
+def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
+ """Clips gradient norm of an iterable of parameters whose gradients are in fp32.
+
+ This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
+ added functionality to handle model parallel parameters.
+
+ Note:
+ the gradients are modified in place.
+
+ Args:
+ parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`):
+ An iterable of Tensors or a single Tensor that will have gradients normalized.
+ max_norm (Union[float, int]): Max norm of the gradients.
+ norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm.
+
+ Returns:
+ float: Total norm of the parameters.
+ """
+
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+
+ # Filter parameters based on:
+ # - grad should not be none
+ # - parameter should not be shared
+ # - should not be a replica due to tensor model parallelism
+ params: List[Parameter] = []
+ has_zero_shared_param: bool = False
+ for param in parameters:
+ if param.grad is not None:
+ # Make sure the grads are in fp32
+ assert param.grad.dtype == torch.float, \
+ f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
+ if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
+ has_zero_shared_param = True
+ params.append(param)
+
+ if len(params) == 0:
+ enable_cuda_kernels = False
+ else:
+ enable_cuda_kernels = params[0].grad.device.type == 'cuda'
+ # Norm parameters.
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+
+ # Parameters can be on CPU or CUDA
+ # If parameters are on CPU, disable CUDA kernels
+
+ # Calculate norm.
+ if norm_type == inf:
+ total_norm = max(p.grad.data.abs().max() for p in params)
+ total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+ # Take max across all model-parallel GPUs.
+ if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=gpc.get_group(ParallelMode.MODEL),
+ async_op=False)
+ if has_zero_shared_param:
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=gpc.get_group(ParallelMode.DATA),
+ async_op=False)
+ total_norm = total_norm_cuda[0].item()
+ else:
+ tensor_parallel_grads = []
+ no_tensor_parallel_grads = []
+ zero_sharded_grads = []
+ for p in params:
+ if is_model_parallel_parameter(p):
+ reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
+ tensor_parallel_grads.append(p.grad.data / reductor)
+ elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
+ zero_sharded_grads.append(p.grad.data)
+ else:
+ no_tensor_parallel_grads.append(p.grad.data)
+
+ if norm_type == 2.0 and enable_cuda_kernels:
+ tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
+ no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
+ zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
+ else:
+ tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
+ no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
+ zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
+ # If norm is type of float, then we convert them into torch.Tensor.
+ tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
+ no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels)
+ zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels)
+ # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
+ if not enable_cuda_kernels:
+ tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
+ no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
+ zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
+
+ # Sum across all model-parallel GPUs.
+ if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
+ dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
+ # Sum across all zero sharded GPUs
+ if len(zero_sharded_grads) > 0:
+ dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
+ no_tensor_parallel_norm += zero_sharded_norm
+ total_norm = tensor_parallel_norm + no_tensor_parallel_norm
+ if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+ dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
+ total_norm = total_norm**(1.0 / norm_type)
+ if torch.is_tensor(total_norm):
+ total_norm = total_norm.item()
+
+ # Scale.
+ clip_coeff = max_norm / (total_norm + 1.0e-6)
+ if clip_coeff < 1.0:
+ if enable_cuda_kernels:
+ grads = [p.grad.detach() for p in params]
+ dummy_overflow_buf = torch.cuda.IntTensor([0])
+ multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
+ else:
+ for p in params:
+ p.grad.detach().mul_(clip_coeff)
+ return total_norm
+
+
+def count_zeros_fp32(parameters):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+
+ # Filter parameters based on:
+ # - grad should not be none
+ # - parameter should not be shared
+ # - should not be a replica due to tensor model parallelism
+ total_num_zeros = 0.0
+ for param in parameters:
+ grad_not_none = param.grad is not None
+ is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+ if grad_not_none and is_not_tp_duplicate:
+ grad = param.grad.detach()
+ num_zeros = grad.numel() - torch.count_nonzero(grad)
+ total_num_zeros = num_zeros + total_num_zeros
+
+ total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
+
+ # Sum across all model-parallel GPUs.
+ ops = []
+ ops.append(
+ dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
+ if gpc.is_initialized(ParallelMode.PIPELINE):
+ ops.append(
+ dist.all_reduce(total_num_zeros,
+ op=dist.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.PIPELINE),
+ async_op=True))
+
+ for req in ops:
+ req.wait()
+ total_num_zeros = total_num_zeros.item()
+
+ return total_num_zeros
+
+
+def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
+ for attr in TENSOR_PARALLEL_ATTRIBUTES:
+ if hasattr(src_tensor, attr):
+ val = getattr(src_tensor, attr)
+ setattr(dst_tensor, attr, val)
+
+
+def param_is_not_tensor_parallel_duplicate(param):
+ return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
+ ParallelMode.TENSOR) == 0)
+
+
+@contextmanager
+def switch_virtual_pipeline_parallel_rank(rank):
+ prev_rank = gpc.virtual_pipeline_parallel_rank
+ try:
+ gpc.set_virtual_pipeline_parallel_rank(rank)
+ yield
+ finally:
+ gpc.set_virtual_pipeline_parallel_rank(prev_rank)
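
A minimal usage sketch for the clipping helpers above, assuming `model` is a `torch.nn.Module` whose fp32 gradients have already been produced by a backward pass:

    from colossalai.legacy.utils import clip_grad_norm_fp32

    total_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=2)
    print(f'pre-clip gradient norm: {total_norm:.4f}')
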
diff --git a/colossalai/utils/data_sampler/__init__.py b/colossalai/legacy/utils/data_sampler/__init__.py
similarity index 100%
rename from colossalai/utils/data_sampler/__init__.py
rename to colossalai/legacy/utils/data_sampler/__init__.py
diff --git a/colossalai/utils/data_sampler/base_sampler.py b/colossalai/legacy/utils/data_sampler/base_sampler.py
similarity index 100%
rename from colossalai/utils/data_sampler/base_sampler.py
rename to colossalai/legacy/utils/data_sampler/base_sampler.py
diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py
similarity index 98%
rename from colossalai/utils/data_sampler/data_parallel_sampler.py
rename to colossalai/legacy/utils/data_sampler/data_parallel_sampler.py
index 881ddde78648..66a5fdd3694d 100644
--- a/colossalai/utils/data_sampler/data_parallel_sampler.py
+++ b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py
@@ -10,8 +10,8 @@
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
T_co = TypeVar('T_co', covariant=True)
diff --git a/colossalai/utils/memory.py b/colossalai/legacy/utils/memory.py
similarity index 95%
rename from colossalai/utils/memory.py
rename to colossalai/legacy/utils/memory.py
index 434e90edd3b9..360bf0da4a77 100644
--- a/colossalai/utils/memory.py
+++ b/colossalai/legacy/utils/memory.py
@@ -1,15 +1,15 @@
-import torch
import gc
-import psutil
from collections import namedtuple
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.utils import get_current_device
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.logging import get_dist_logger
+import psutil
+import torch
+import torch.distributed as dist
from packaging import version
+from colossalai.legacy.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.utils import get_current_device
+
_GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1
@@ -68,7 +68,7 @@ def report_memory_usage(message, logger=None, report_cpu=False):
Raises:
EnvironmentError: Raised if no distributed environment has been initialized.
"""
- if not gpc.is_initialized(ParallelMode.GLOBAL):
+ if not dist.is_initialized():
raise EnvironmentError("No distributed environment is initialized")
gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated())
@@ -138,7 +138,7 @@ def colo_device_memory_used(device: torch.device) -> int:
def colo_set_process_memory_fraction(ratio: float) -> None:
- """colo_set_process_memory_fraction
+ """colo_set_process_memory_fraction
set how much CUDA memory can be used on the GPU belonging to the current process.
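
The moved memory helpers are used as below (sketch; a distributed environment must be initialized first, which `report_memory_usage` now checks via `dist.is_initialized()`):

    from colossalai.legacy.utils import colo_set_process_memory_fraction, report_memory_usage

    colo_set_process_memory_fraction(0.5)               # cap this process at 50% of its GPU
    report_memory_usage('after model construction')     # logs GPU (and optionally CPU) usage
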
diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/legacy/utils/profiler/__init__.py
similarity index 100%
rename from colossalai/utils/profiler/__init__.py
rename to colossalai/legacy/utils/profiler/__init__.py
diff --git a/colossalai/utils/profiler/extention.py b/colossalai/legacy/utils/profiler/extention.py
similarity index 100%
rename from colossalai/utils/profiler/extention.py
rename to colossalai/legacy/utils/profiler/extention.py
diff --git a/colossalai/utils/profiler/legacy/__init__.py b/colossalai/legacy/utils/profiler/legacy/__init__.py
similarity index 77%
rename from colossalai/utils/profiler/legacy/__init__.py
rename to colossalai/legacy/utils/profiler/legacy/__init__.py
index 849c7fca3053..88beed86d7de 100644
--- a/colossalai/utils/profiler/legacy/__init__.py
+++ b/colossalai/legacy/utils/profiler/legacy/__init__.py
@@ -1,6 +1,6 @@
-from .comm_profiler import CommProfiler
-from .pcie_profiler import PcieProfiler
-from .prof_utils import ProfilerContext, BaseProfiler
-from .mem_profiler import MemProfiler
-
-__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
+from .comm_profiler import CommProfiler
+from .mem_profiler import MemProfiler
+from .pcie_profiler import PcieProfiler
+from .prof_utils import BaseProfiler, ProfilerContext
+
+__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
diff --git a/colossalai/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py
similarity index 96%
rename from colossalai/utils/profiler/legacy/comm_profiler.py
rename to colossalai/legacy/utils/profiler/legacy/comm_profiler.py
index 334f0113ee90..bb7e2654c740 100644
--- a/colossalai/utils/profiler/legacy/comm_profiler.py
+++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py
@@ -1,308 +1,311 @@
-import inspect
-from pathlib import Path
-from functools import partial
-import torch
-from torch.autograd.profiler import profile
-import torch.distributed as dist
-from torch.distributed import ReduceOp
-from colossalai.utils import get_current_device
-from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth
-from typing import List, Optional
-
-
-def _get_code_location(depth: int):
- ret = []
- length = min(len(inspect.stack()), depth + 1)
- for i in range(3, length):
- upper_frame = inspect.stack()[i]
- function_name = inspect.stack()[i - 1].function
- ret.append(upper_frame.filename)
- ret.append('(')
- ret.append(str(upper_frame.lineno))
- ret.append('): ')
- ret.append(function_name)
- if i != length - 1:
- ret.append('\n')
-
- return ''.join(ret)
-
-
-torch_all_reduce = dist.all_reduce
-torch_all_gather = dist.all_gather
-torch_reduce_scatter = dist.reduce_scatter
-torch_broadcast = dist.broadcast
-torch_reduce = dist.reduce
-
-
-class CommEvent(object):
- """Communication Event. Used for communication time and communication
- volume recording.
- """
-
- def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
- self.self_count = count
- self.self_comm_vol = comm_vol
- self.self_cuda_time = cuda_time
-
- def add(self, rhs):
- self.self_count += rhs.self_count
- self.self_comm_vol += rhs.self_comm_vol
- self.self_cuda_time += rhs.self_cuda_time
-
-
-class CommProfiler(BaseProfiler):
- """Communication profiler. Records all communication events.
- """
-
- def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
- super().__init__(profiler_name="Collective_Communication", priority=0)
- self.depth = 3 + depth
- self.total_count = total_count
- self.total_comm_vol = total_comm_vol
- self.total_cuda_time = total_cuda_time
-
- self.ops_record = dict()
- self.profiler = None
- self.pending_op = None
- self.pending_metadata = None
- self.warn_flag = False
-
- def reset(self):
- self.total_count = 0
- self.total_comm_vol = 0
- self.total_cuda_time = 0
-
- self.ops_record = dict()
- self.profiler = None
- self.pending_op = None
- self.pending_metadata = None
- self.warn_flag = False
-
- def enable(self):
- dist.all_reduce = partial(all_reduce, profiler=self)
- dist.all_gather = partial(all_gather, profiler=self)
- dist.reduce_scatter = partial(reduce_scatter, profiler=self)
- dist.broadcast = partial(broadcast, profiler=self)
- dist.reduce = partial(reduce, profiler=self)
-
- def disable(self):
- dist.all_reduce = torch_all_reduce
- dist.all_gather = torch_all_gather
- dist.reduce_scatter = torch_reduce_scatter
- dist.broadcast = torch_broadcast
- dist.reduce = torch_reduce
-
- def to_tensorboard(self, writer):
- writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n"))
-
- def to_file(self, filename: Path):
- with open(filename, "w") as f:
- f.write(self.result_str())
-
- def show(self):
- print(self.result_str())
-
- def result_str(self, sep: str = "\n"):
- res = []
-
- def append(s: str = None):
- if s is not None:
- res.append(s)
- res.append(sep)
-
- if self.warn_flag:
- append("Warning: there exists multiple communication operations in the same time. As a result, "
- "the profiling result is not accurate.")
-
- if self.total_cuda_time == 0:
- return "No collective communication has been called yet!"
-
- append("Collective communication profiling result:")
- append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
- append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)))
- append("total number of calls: {}".format(self.total_count))
- append("All events:")
-
- separation = '-' * 74
- row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
-
- append(separation)
- append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
- append(separation)
-
- show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
- for location, event in show_list:
- append(location)
- append(
- row_format.format('', _format_time(event.self_cuda_time),
- '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
- _format_memory(event.self_comm_vol),
- _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
- append()
-
- return ''.join(res)
-
- @property
- def has_aync_op(self):
- return self.pending_op is not None
-
- def activate_profiler(self, kn: str, vol: float):
- self.pending_metadata = (kn, _get_code_location(self.depth), vol)
- self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
- self.profiler.__enter__()
-
- def close_profiler(self, group=None):
- assert self.profiler is not None, "There is no running dist op"
- kernel_name, code_location, vol = self.pending_metadata
- self.profiler.__exit__(None, None, None)
-
- if self.profiler.enabled and dist.get_world_size(group) > 1:
- assert_flag = 0
- current_comm_event = None
- events = self.profiler.function_events
- for event in events:
- if kernel_name in event.name:
- assert assert_flag == 0, "Multiple dist ops has been called "
- current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
- assert_flag += 1
-
- assert current_comm_event is not None, "dist op has not been found"
-
- buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
- torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
- current_comm_event.self_cuda_time = buffer.item()
-
- self.total_count += current_comm_event.self_count
- self.total_comm_vol += current_comm_event.self_comm_vol
- self.total_cuda_time += current_comm_event.self_cuda_time
- if code_location in self.ops_record:
- self.ops_record[code_location].add(current_comm_event)
- else:
- self.ops_record[code_location] = current_comm_event
-
- self.profiler = None
- self.pending_op = None
- self.pending_metadata = None
-
- def wait_async_op(self):
- if self.pending_op is not None:
- op = self.pending_op
- op.wait()
- self.close_profiler()
-
-
-class CommHandler(object):
- """Communication handler. A dummy handler to wait aync operations.
- """
-
- def __init__(self, profiler: CommProfiler):
- super().__init__()
- self.prof = profiler
-
- def wait(self):
- self.prof.wait_async_op()
-
-
-def async_check(profiler: CommProfiler):
- if profiler.pending_op is not None:
- profiler.warn_flag = True
- profiler.wait_async_op()
-
-
-def all_reduce(tensor: torch.Tensor,
- op: ReduceOp = ReduceOp.SUM,
- group=None,
- async_op: bool = False,
- profiler: CommProfiler = None) -> Optional[CommHandler]:
- async_check(profiler)
-
- comm_size = dist.get_world_size(group)
- correction = 2 * (comm_size - 1) / comm_size
- comm_vol = correction * tensor.element_size() * tensor.numel()
- profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
- profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
-
- if async_op:
- return CommHandler(profiler)
-
- profiler.close_profiler(group)
-
-
-def reduce_scatter(output: torch.Tensor,
- input_list: List[torch.Tensor],
- op: ReduceOp = ReduceOp.SUM,
- group=None,
- async_op: bool = False,
- profiler: CommProfiler = None) -> Optional[CommHandler]:
- async_check(profiler)
-
- comm_size = dist.get_world_size(group)
- correction = (comm_size - 1) / comm_size
- comm_vol = 0
- for tensor in input_list:
- comm_vol += tensor.element_size() * tensor.numel()
- comm_vol *= correction
- profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
- profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
-
- if async_op:
- return CommHandler(profiler)
-
- profiler.close_profiler(group)
-
-
-def all_gather(tensor_list: List[torch.Tensor],
- tensor: torch.Tensor,
- group=None,
- async_op: bool = False,
- profiler: CommProfiler = None) -> Optional[CommHandler]:
- async_check(profiler)
-
- comm_size = dist.get_world_size(group)
- correction = (comm_size - 1) / comm_size
- comm_vol = 0
- for ten in tensor_list:
- comm_vol += ten.element_size() * ten.numel()
- comm_vol *= correction
- profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
- profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
-
- if async_op:
- return CommHandler(profiler)
-
- profiler.close_profiler(group)
-
-
-def broadcast(tensor: torch.Tensor,
- src: int,
- group=None,
- async_op: bool = False,
- profiler: CommProfiler = None) -> Optional[CommHandler]:
- async_check(profiler)
-
- comm_vol = 1.0 * tensor.element_size() * tensor.numel()
- profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
- profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
-
- if async_op:
- return CommHandler(profiler)
-
- profiler.close_profiler(group)
-
-
-def reduce(tensor: torch.Tensor,
- dst: int,
- op: ReduceOp = ReduceOp.SUM,
- group=None,
- async_op: bool = False,
- profiler: CommProfiler = None) -> Optional[CommHandler]:
- async_check(profiler)
-
- comm_vol = 1.0 * tensor.element_size() * tensor.numel()
- profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
- profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
-
- if async_op:
- return CommHandler(profiler)
-
- profiler.close_profiler(group)
+import inspect
+from functools import partial
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+from torch.autograd.profiler import profile
+from torch.distributed import ReduceOp
+
+from colossalai.utils import get_current_device
+
+from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
+
+
+def _get_code_location(depth: int):
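+    # Build a "file(line): function" trace of the call stack, skipping the
+    # profiler's internal frames (stack indices 0-2).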
+ ret = []
+ length = min(len(inspect.stack()), depth + 1)
+ for i in range(3, length):
+ upper_frame = inspect.stack()[i]
+ function_name = inspect.stack()[i - 1].function
+ ret.append(upper_frame.filename)
+ ret.append('(')
+ ret.append(str(upper_frame.lineno))
+ ret.append('): ')
+ ret.append(function_name)
+ if i != length - 1:
+ ret.append('\n')
+
+ return ''.join(ret)
+
+
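+# Keep references to the original torch.distributed collectives so that
+# CommProfiler.enable() can patch them and disable() can restore them.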
+torch_all_reduce = dist.all_reduce
+torch_all_gather = dist.all_gather
+torch_reduce_scatter = dist.reduce_scatter
+torch_broadcast = dist.broadcast
+torch_reduce = dist.reduce
+
+
+class CommEvent(object):
+    """Communication event. Records the call count, communication volume and
+    CUDA time of collective communication operations.
+    """
+
+ def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0):
+ self.self_count = count
+ self.self_comm_vol = comm_vol
+ self.self_cuda_time = cuda_time
+
+ def add(self, rhs):
+ self.self_count += rhs.self_count
+ self.self_comm_vol += rhs.self_comm_vol
+ self.self_cuda_time += rhs.self_cuda_time
+
+
+class CommProfiler(BaseProfiler):
+ """Communication profiler. Records all communication events.
+ """
+
+ def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0):
+ super().__init__(profiler_name="Collective_Communication", priority=0)
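+        # Offset by 3 so _get_code_location skips the profiler's own frames
+        # (itself, activate_profiler and the patched collective) and reports
+        # the user code that issued the op.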
+ self.depth = 3 + depth
+ self.total_count = total_count
+ self.total_comm_vol = total_comm_vol
+ self.total_cuda_time = total_cuda_time
+
+ self.ops_record = dict()
+ self.profiler = None
+ self.pending_op = None
+ self.pending_metadata = None
+ self.warn_flag = False
+
+ def reset(self):
+ self.total_count = 0
+ self.total_comm_vol = 0
+ self.total_cuda_time = 0
+
+ self.ops_record = dict()
+ self.profiler = None
+ self.pending_op = None
+ self.pending_metadata = None
+ self.warn_flag = False
+
+ def enable(self):
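+        # Patch the torch.distributed collectives with wrappers bound to this
+        # profiler instance via functools.partial.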
+ dist.all_reduce = partial(all_reduce, profiler=self)
+ dist.all_gather = partial(all_gather, profiler=self)
+ dist.reduce_scatter = partial(reduce_scatter, profiler=self)
+ dist.broadcast = partial(broadcast, profiler=self)
+ dist.reduce = partial(reduce, profiler=self)
+
+ def disable(self):
+ dist.all_reduce = torch_all_reduce
+ dist.all_gather = torch_all_gather
+ dist.reduce_scatter = torch_reduce_scatter
+ dist.broadcast = torch_broadcast
+ dist.reduce = torch_reduce
+
+ def to_tensorboard(self, writer):
+ writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n"))
+
+ def to_file(self, filename: Path):
+ with open(filename, "w") as f:
+ f.write(self.result_str())
+
+ def show(self):
+ print(self.result_str())
+
+ def result_str(self, sep: str = "\n"):
+ res = []
+
+ def append(s: str = None):
+ if s is not None:
+ res.append(s)
+ res.append(sep)
+
+ if self.warn_flag:
+ append("Warning: there exists multiple communication operations in the same time. As a result, "
+ "the profiling result is not accurate.")
+
+ if self.total_cuda_time == 0:
+ return "No collective communication has been called yet!"
+
+ append("Collective communication profiling result:")
+ append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
+ append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)))
+ append("total number of calls: {}".format(self.total_count))
+ append("All events:")
+
+ separation = '-' * 74
+ row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
+
+ append(separation)
+ append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
+ append(separation)
+
+ show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
+ for location, event in show_list:
+ append(location)
+ append(
+ row_format.format('', _format_time(event.self_cuda_time),
+ '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0),
+ _format_memory(event.self_comm_vol),
+ _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count))
+ append()
+
+ return ''.join(res)
+
+ @property
+    def has_async_op(self):
+ return self.pending_op is not None
+
+ def activate_profiler(self, kn: str, vol: float):
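+        # Record what is about to run, then start a fresh torch profiler;
+        # close_profiler() parses its events to extract the kernel time.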
+ self.pending_metadata = (kn, _get_code_location(self.depth), vol)
+ self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True)
+ self.profiler.__enter__()
+
+ def close_profiler(self, group=None):
+ assert self.profiler is not None, "There is no running dist op"
+ kernel_name, code_location, vol = self.pending_metadata
+ self.profiler.__exit__(None, None, None)
+
+ if self.profiler.enabled and dist.get_world_size(group) > 1:
+ assert_flag = 0
+ current_comm_event = None
+ events = self.profiler.function_events
+ for event in events:
+ if kernel_name in event.name:
+                    assert assert_flag == 0, "Multiple dist ops have been called"
+ current_comm_event = CommEvent(1, vol, event.self_cuda_time_total)
+ assert_flag += 1
+
+ assert current_comm_event is not None, "dist op has not been found"
+
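+            # Aggregate the measurement across ranks with the unpatched
+            # all_reduce (so the profiler does not profile itself), keeping the
+            # minimum CUDA time: slower ranks presumably include waiting time.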
+ buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
+ torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
+ current_comm_event.self_cuda_time = buffer.item()
+
+ self.total_count += current_comm_event.self_count
+ self.total_comm_vol += current_comm_event.self_comm_vol
+ self.total_cuda_time += current_comm_event.self_cuda_time
+ if code_location in self.ops_record:
+ self.ops_record[code_location].add(current_comm_event)
+ else:
+ self.ops_record[code_location] = current_comm_event
+
+ self.profiler = None
+ self.pending_op = None
+ self.pending_metadata = None
+
+ def wait_async_op(self):
+ if self.pending_op is not None:
+ op = self.pending_op
+ op.wait()
+ self.close_profiler()
+
+
+class CommHandler(object):
+    """Communication handler. A dummy handler used to wait for asynchronous
+    communication operations.
+    """
+
+ def __init__(self, profiler: CommProfiler):
+ super().__init__()
+ self.prof = profiler
+
+ def wait(self):
+ self.prof.wait_async_op()
+
+
+def async_check(profiler: CommProfiler):
+ if profiler.pending_op is not None:
+ profiler.warn_flag = True
+ profiler.wait_async_op()
+
+
+def all_reduce(tensor: torch.Tensor,
+ op: ReduceOp = ReduceOp.SUM,
+ group=None,
+ async_op: bool = False,
+ profiler: CommProfiler = None) -> Optional[CommHandler]:
+ async_check(profiler)
+
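+    # Volume correction assuming a ring-style all-reduce (as in NCCL): each
+    # element travels 2 * (n - 1) / n times across the reduce-scatter and
+    # all-gather phases.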
+ comm_size = dist.get_world_size(group)
+ correction = 2 * (comm_size - 1) / comm_size
+ comm_vol = correction * tensor.element_size() * tensor.numel()
+ profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol)
+ profiler.pending_op = torch_all_reduce(tensor, op, group, async_op)
+
+ if async_op:
+ return CommHandler(profiler)
+
+ profiler.close_profiler(group)
+
+
+def reduce_scatter(output: torch.Tensor,
+ input_list: List[torch.Tensor],
+ op: ReduceOp = ReduceOp.SUM,
+ group=None,
+ async_op: bool = False,
+ profiler: CommProfiler = None) -> Optional[CommHandler]:
+ async_check(profiler)
+
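+    # Ring-style reduce-scatter moves each element (n - 1) / n times.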
+ comm_size = dist.get_world_size(group)
+ correction = (comm_size - 1) / comm_size
+ comm_vol = 0
+ for tensor in input_list:
+ comm_vol += tensor.element_size() * tensor.numel()
+ comm_vol *= correction
+ profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol)
+ profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op)
+
+ if async_op:
+ return CommHandler(profiler)
+
+ profiler.close_profiler(group)
+
+
+def all_gather(tensor_list: List[torch.Tensor],
+ tensor: torch.Tensor,
+ group=None,
+ async_op: bool = False,
+ profiler: CommProfiler = None) -> Optional[CommHandler]:
+ async_check(profiler)
+
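+    # Ring-style all-gather likewise moves each element (n - 1) / n times.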
+ comm_size = dist.get_world_size(group)
+ correction = (comm_size - 1) / comm_size
+ comm_vol = 0
+ for ten in tensor_list:
+ comm_vol += ten.element_size() * ten.numel()
+ comm_vol *= correction
+ profiler.activate_profiler("ncclKernel_AllGather_", comm_vol)
+ profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op)
+
+ if async_op:
+ return CommHandler(profiler)
+
+ profiler.close_profiler(group)
+
+
+def broadcast(tensor: torch.Tensor,
+ src: int,
+ group=None,
+ async_op: bool = False,
+ profiler: CommProfiler = None) -> Optional[CommHandler]:
+ async_check(profiler)
+
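+    # Broadcast transfers the whole tensor once, so the raw size is used.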
+ comm_vol = 1.0 * tensor.element_size() * tensor.numel()
+ profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol)
+ profiler.pending_op = torch_broadcast(tensor, src, group, async_op)
+
+ if async_op:
+ return CommHandler(profiler)
+
+ profiler.close_profiler(group)
+
+
+def reduce(tensor: torch.Tensor,
+ dst: int,
+ op: ReduceOp = ReduceOp.SUM,
+ group=None,
+ async_op: bool = False,
+ profiler: CommProfiler = None) -> Optional[CommHandler]:
+ async_check(profiler)
+
+ comm_vol = 1.0 * tensor.element_size() * tensor.numel()
+ profiler.activate_profiler("ncclKernel_Reduce_", comm_vol)
+ profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op)
+
+ if async_op:
+ return CommHandler(profiler)
+
+ profiler.close_profiler(group)
diff --git a/colossalai/utils/profiler/legacy/pcie_profiler.py b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py
similarity index 95%
rename from colossalai/utils/profiler/legacy/pcie_profiler.py
rename to colossalai/legacy/utils/profiler/legacy/pcie_profiler.py
index 8f812f5cfc7b..514d3c6fabfa 100644
--- a/colossalai/utils/profiler/legacy/pcie_profiler.py
+++ b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py
@@ -1,148 +1,150 @@
-from pathlib import Path
-from torch.autograd.profiler import profile
-from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth
-from typing import List
-
-
-def _get_size(dtype: str):
- if dtype == "fp16":
- return 2
- elif dtype == "fp32":
- return 4
- else:
- raise NotImplementedError
-
-
-def _get_numel(my_list: List[int]) -> int:
- from functools import reduce
- from operator import mul
- return reduce(mul, my_list)
-
-
-def _reduce_location(locations: List[str]) -> str:
- ret = []
- for lo in locations:
- ret.append(lo)
- ret.append("\n")
- ret = ret[:-1]
- return ''.join(ret)
-
-
-class PcieEvent(object):
- """Pcie Event.
- """
-
- def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0):
- self.count = count
- self.pcie_vol = pcie_vol
- self.cuda_time = cuda_time
-
- def add(self, rhs):
- self.count += rhs.count
- self.pcie_vol += rhs.pcie_vol
- self.cuda_time += rhs.cuda_time
-
-
-class PcieProfiler(BaseProfiler):
- """Pcie profiler. Records all data transmission between CPU and GPU.
-
- TODO: Merge pcie profiler into communication profiler
- """
-
- def __init__(self, dtype: str = "fp32", depth: int = 1):
- super().__init__(profiler_name="Pcie", priority=10)
- self.depth = depth
- self.data_size = _get_size(dtype)
- self.h2d_count = 0
- self.h2d_time = 0
- self.d2h_count = 0
- self.d2h_time = 0
-
- self.ops_record = dict()
- self.profiler = None
-
- def reset(self):
- self.h2d_count = 0
- self.h2d_time = 0
- self.d2h_count = 0
- self.d2h_time = 0
-
- self.ops_record = dict()
- self.profiler = None
-
- def enable(self):
- self.profiler = profile(enabled=True,
- use_cuda=True,
- use_cpu=True,
- use_kineto=True,
- record_shapes=True,
- with_stack=True)
- self.profiler.__enter__()
-
- def disable(self):
- self.profiler.__exit__(None, None, None)
-
- if self.profiler.enabled:
- events = self.profiler.function_events
- for event in events:
- if event.name == "aten::copy_":
- t_shape = event.input_shapes[0]
- if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
- continue
- current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
- code_location = _reduce_location(event.stack[:self.depth])
- if code_location in self.ops_record:
- self.ops_record[code_location].add(current_comm_event)
- else:
- self.ops_record[code_location] = current_comm_event
- elif 'Memcpy HtoD' in event.name:
- self.h2d_count += 1
- self.h2d_time += event.cuda_time_total
- elif 'Memcpy DtoH' in event.name:
- self.d2h_count += 1
- self.d2h_time += event.cuda_time_total
-
- self.profiler = None
-
- def to_tensorboard(self, writer):
- writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n"))
-
- def to_file(self, filename: Path):
- with open(filename, "w") as f:
- f.write(self.result_str())
-
- def show(self):
- print(self.result_str())
-
- def result_str(self, sep: str = "\n"):
- res = []
-
- def append(s: str = None):
- if s is not None:
- res.append(s)
- res.append(sep)
-
- append("Pcie profiling result:")
- append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time)))
- append("number of transmission (CPU -> GPU): {}".format(self.h2d_count))
- append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time)))
- append("number of transmission (GPU -> CPU): {}".format(self.d2h_count))
-
- append("Possible data transmission events in PCIE:")
-
- separation = '-' * 62
- row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
-
- append(separation)
- append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
- append(separation)
-
- show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
- for location, event in show_list:
- append(location)
- append(
- row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
- _format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
- append()
-
- return ''.join(res)
+from pathlib import Path
+from typing import List
+
+from torch.autograd.profiler import profile
+
+from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
+
+
+def _get_size(dtype: str):
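+    # Bytes per element for the supported dtypes.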
+ if dtype == "fp16":
+ return 2
+ elif dtype == "fp32":
+ return 4
+ else:
+ raise NotImplementedError
+
+
+def _get_numel(my_list: List[int]) -> int:
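+    # The number of elements is the product of all dimensions.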
+ from functools import reduce
+ from operator import mul
+ return reduce(mul, my_list)
+
+
+def _reduce_location(locations: List[str]) -> str:
+ ret = []
+ for lo in locations:
+ ret.append(lo)
+ ret.append("\n")
+ ret = ret[:-1]
+ return ''.join(ret)
+
+
+class PcieEvent(object):
+    """PCIe event. Records the call count, transfer volume and CUDA time of
+    data transfers over PCIe.
+    """
+
+ def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0):
+ self.count = count
+ self.pcie_vol = pcie_vol
+ self.cuda_time = cuda_time
+
+ def add(self, rhs):
+ self.count += rhs.count
+ self.pcie_vol += rhs.pcie_vol
+ self.cuda_time += rhs.cuda_time
+
+
+class PcieProfiler(BaseProfiler):
+    """PCIe profiler. Records all data transfers between CPU and GPU.
+
+    TODO: Merge the PCIe profiler into the communication profiler.
+    """
+
+ def __init__(self, dtype: str = "fp32", depth: int = 1):
+ super().__init__(profiler_name="Pcie", priority=10)
+ self.depth = depth
+ self.data_size = _get_size(dtype)
+ self.h2d_count = 0
+ self.h2d_time = 0
+ self.d2h_count = 0
+ self.d2h_time = 0
+
+ self.ops_record = dict()
+ self.profiler = None
+
+ def reset(self):
+ self.h2d_count = 0
+ self.h2d_time = 0
+ self.d2h_count = 0
+ self.d2h_time = 0
+
+ self.ops_record = dict()
+ self.profiler = None
+
+ def enable(self):
+ self.profiler = profile(enabled=True,
+ use_cuda=True,
+ use_cpu=True,
+ use_kineto=True,
+ record_shapes=True,
+ with_stack=True)
+ self.profiler.__enter__()
+
+ def disable(self):
+ self.profiler.__exit__(None, None, None)
+
+ if self.profiler.enabled:
+ events = self.profiler.function_events
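+            # aten::copy_ events with recorded shapes are treated as candidate
+            # host<->device copies; explicit Memcpy kernels are tallied into
+            # the H2D / D2H totals directly.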
+ for event in events:
+ if event.name == "aten::copy_":
+ t_shape = event.input_shapes[0]
+ if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0:
+ continue
+ current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total)
+ code_location = _reduce_location(event.stack[:self.depth])
+ if code_location in self.ops_record:
+ self.ops_record[code_location].add(current_comm_event)
+ else:
+ self.ops_record[code_location] = current_comm_event
+ elif 'Memcpy HtoD' in event.name:
+ self.h2d_count += 1
+ self.h2d_time += event.cuda_time_total
+ elif 'Memcpy DtoH' in event.name:
+ self.d2h_count += 1
+ self.d2h_time += event.cuda_time_total
+
+ self.profiler = None
+
+ def to_tensorboard(self, writer):
+ writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n"))
+
+ def to_file(self, filename: Path):
+ with open(filename, "w") as f:
+ f.write(self.result_str())
+
+ def show(self):
+ print(self.result_str())
+
+ def result_str(self, sep: str = "\n"):
+ res = []
+
+ def append(s: str = None):
+ if s is not None:
+ res.append(s)
+ res.append(sep)
+
+ append("Pcie profiling result:")
+ append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time)))
+ append("number of transmission (CPU -> GPU): {}".format(self.h2d_count))
+ append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time)))
+ append("number of transmission (GPU -> CPU): {}".format(self.d2h_count))
+
+ append("Possible data transmission events in PCIE:")
+
+ separation = '-' * 62
+ row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
+
+ append(separation)
+ append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
+ append(separation)
+
+ show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
+ for location, event in show_list:
+ append(location)
+ append(
+ row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol),
+ _format_bandwidth(event.pcie_vol, event.cuda_time), event.count))
+ append()
+
+ return ''.join(res)
diff --git a/colossalai/utils/profiler/legacy/prof_utils.py b/colossalai/legacy/utils/profiler/legacy/prof_utils.py
similarity index 94%
rename from colossalai/utils/profiler/legacy/prof_utils.py
rename to colossalai/legacy/utils/profiler/legacy/prof_utils.py
index 2f7eee827651..9b948c9ec1cd 100644
--- a/colossalai/utils/profiler/legacy/prof_utils.py
+++ b/colossalai/legacy/utils/profiler/legacy/prof_utils.py
@@ -1,131 +1,132 @@
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Union, List
-from colossalai.core import global_context as gpc
-
-
-# copied from high version pytorch to support low version
-def _format_time(time_us):
- """Defines how to format time in FunctionEvent"""
- US_IN_SECOND = 1000.0 * 1000.0
- US_IN_MS = 1000.0
- if time_us >= US_IN_SECOND:
- return '{:.3f}s'.format(time_us / US_IN_SECOND)
- if time_us >= US_IN_MS:
- return '{:.3f}ms'.format(time_us / US_IN_MS)
- return '{:.3f}us'.format(time_us)
-
-
-# copied from high version pytorch to support low version
-def _format_memory(nbytes):
- """Returns a formatted memory size string"""
- KB = 1024
- MB = 1024 * KB
- GB = 1024 * MB
- if (abs(nbytes) >= GB):
- return '{:.2f} GB'.format(nbytes * 1.0 / GB)
- elif (abs(nbytes) >= MB):
- return '{:.2f} MB'.format(nbytes * 1.0 / MB)
- elif (abs(nbytes) >= KB):
- return '{:.2f} KB'.format(nbytes * 1.0 / KB)
- else:
- return str(nbytes) + ' B'
-
-
-def _format_bandwidth(volume: float or int, time_us: int):
- sec_div_mb = (1000.0 / 1024.0)**2
- mb_per_sec = volume / time_us * sec_div_mb
-
- if mb_per_sec >= 1024.0:
- return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
- else:
- return '{:.3f} MB/s'.format(mb_per_sec)
-
-
-class BaseProfiler(ABC):
-
- def __init__(self, profiler_name: str, priority: int):
- self.name = profiler_name
- self.priority = priority
-
- @abstractmethod
- def enable(self):
- pass
-
- @abstractmethod
- def disable(self):
- pass
-
- @abstractmethod
- def to_tensorboard(self, writer):
- pass
-
- @abstractmethod
- def to_file(self, filename: Path):
- pass
-
- @abstractmethod
- def show(self):
- pass
-
-
-class ProfilerContext(object):
- """Profiler context manager
-
- Usage::
-
- world_size = 4
- inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
- outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
- outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
-
- cc_prof = CommProfiler()
-
- with ProfilerContext([cc_prof]) as prof:
- op = dist.all_reduce(inputs, async_op=True)
- dist.all_gather(outputs_list, inputs)
- op.wait()
- dist.reduce_scatter(inputs, outputs_list)
- dist.broadcast(inputs, 0)
- dist.reduce(inputs, 0)
-
- prof.show()
- """
-
- def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
- self.enable = enable
- self.profilers = sorted(profilers, key=lambda prof: prof.priority)
-
- def __enter__(self):
- if self.enable:
- for prof in self.profilers:
- prof.enable()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.enable:
- for prof in self.profilers:
- prof.disable()
-
- def to_tensorboard(self, writer):
- from torch.utils.tensorboard import SummaryWriter
-
- assert isinstance(writer, SummaryWriter), \
- f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
-
- for prof in self.profilers:
- prof.to_tensorboard(writer)
-
- def to_file(self, log_dir: Union[str, Path]):
- if isinstance(log_dir, str):
- log_dir = Path(log_dir)
-
- if not log_dir.exists():
- log_dir.mkdir(parents=True, exist_ok=True)
- for prof in self.profilers:
- log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
- prof.to_file(log_file)
-
- def show(self):
- for prof in self.profilers:
- prof.show()
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import List, Union
+
+from colossalai.legacy.core import global_context as gpc
+
+
+# Copied from a newer version of PyTorch to support older versions.
+def _format_time(time_us):
+ """Defines how to format time in FunctionEvent"""
+ US_IN_SECOND = 1000.0 * 1000.0
+ US_IN_MS = 1000.0
+ if time_us >= US_IN_SECOND:
+ return '{:.3f}s'.format(time_us / US_IN_SECOND)
+ if time_us >= US_IN_MS:
+ return '{:.3f}ms'.format(time_us / US_IN_MS)
+ return '{:.3f}us'.format(time_us)
+
+
+# Copied from a newer version of PyTorch to support older versions.
+def _format_memory(nbytes):
+ """Returns a formatted memory size string"""
+ KB = 1024
+ MB = 1024 * KB
+ GB = 1024 * MB
+ if (abs(nbytes) >= GB):
+ return '{:.2f} GB'.format(nbytes * 1.0 / GB)
+ elif (abs(nbytes) >= MB):
+ return '{:.2f} MB'.format(nbytes * 1.0 / MB)
+ elif (abs(nbytes) >= KB):
+ return '{:.2f} KB'.format(nbytes * 1.0 / KB)
+ else:
+ return str(nbytes) + ' B'
+
+
+def _format_bandwidth(volume: Union[float, int], time_us: int):
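+    # bytes / microsecond -> MB/s: multiply by 1e6 (us per second) and divide
+    # by 1024 ** 2 (bytes per MB), i.e. scale by (1000 / 1024) ** 2.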
+ sec_div_mb = (1000.0 / 1024.0)**2
+ mb_per_sec = volume / time_us * sec_div_mb
+
+ if mb_per_sec >= 1024.0:
+ return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
+ else:
+ return '{:.3f} MB/s'.format(mb_per_sec)
+
+
+class BaseProfiler(ABC):
+
+ def __init__(self, profiler_name: str, priority: int):
+ self.name = profiler_name
+ self.priority = priority
+
+ @abstractmethod
+ def enable(self):
+ pass
+
+ @abstractmethod
+ def disable(self):
+ pass
+
+ @abstractmethod
+ def to_tensorboard(self, writer):
+ pass
+
+ @abstractmethod
+ def to_file(self, filename: Path):
+ pass
+
+ @abstractmethod
+ def show(self):
+ pass
+
+
+class ProfilerContext(object):
+ """Profiler context manager
+
+ Usage::
+
+ world_size = 4
+ inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
+ outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
+ outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
+
+ cc_prof = CommProfiler()
+
+ with ProfilerContext([cc_prof]) as prof:
+ op = dist.all_reduce(inputs, async_op=True)
+ dist.all_gather(outputs_list, inputs)
+ op.wait()
+ dist.reduce_scatter(inputs, outputs_list)
+ dist.broadcast(inputs, 0)
+ dist.reduce(inputs, 0)
+
+ prof.show()
+ """
+
+ def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
+ self.enable = enable
+ self.profilers = sorted(profilers, key=lambda prof: prof.priority)
+
+ def __enter__(self):
+ if self.enable:
+ for prof in self.profilers:
+ prof.enable()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.enable:
+ for prof in self.profilers:
+ prof.disable()
+
+ def to_tensorboard(self, writer):
+ from torch.utils.tensorboard import SummaryWriter
+
+ assert isinstance(writer, SummaryWriter), \
+ f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.'
+
+ for prof in self.profilers:
+ prof.to_tensorboard(writer)
+
+ def to_file(self, log_dir: Union[str, Path]):
+ if isinstance(log_dir, str):
+ log_dir = Path(log_dir)
+
+ if not log_dir.exists():
+ log_dir.mkdir(parents=True, exist_ok=True)
+ for prof in self.profilers:
+ log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log')
+ prof.to_file(log_file)
+
+ def show(self):
+ for prof in self.profilers:
+ prof.show()
diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/legacy/utils/profiler/profiler.py
similarity index 97%
rename from colossalai/utils/profiler/profiler.py
rename to colossalai/legacy/utils/profiler/profiler.py
index 3026d723deb0..0827f06b586c 100644
--- a/colossalai/utils/profiler/profiler.py
+++ b/colossalai/legacy/utils/profiler/profiler.py
@@ -9,9 +9,9 @@
from torch.profiler.profiler import ProfilerAction
from colossalai.legacy.engine import Engine
+from colossalai.legacy.utils.profiler.extention import ProfilerExtension
+from colossalai.legacy.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
from colossalai.logging import get_dist_logger
-from colossalai.utils.profiler.extention import ProfilerExtension
-from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
class profile(torch_profile):
diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py
similarity index 98%
rename from colossalai/utils/profiler/stateful_tensor_mem_extention.py
rename to colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py
index 412bd7277eee..f3bb66ced583 100644
--- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py
+++ b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py
@@ -9,7 +9,7 @@
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.legacy.engine import Engine
-from colossalai.utils.profiler.extention import ProfilerExtension
+from colossalai.legacy.utils.profiler.extention import ProfilerExtension
class DeviceType(Enum):
diff --git a/colossalai/zero/legacy/__init__.py b/colossalai/legacy/zero/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/__init__.py
rename to colossalai/legacy/zero/__init__.py
diff --git a/colossalai/zero/legacy/gemini/__init__.py b/colossalai/legacy/zero/gemini/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/__init__.py
rename to colossalai/legacy/zero/gemini/__init__.py
diff --git a/colossalai/zero/legacy/gemini/gemini_context.py b/colossalai/legacy/zero/gemini/gemini_context.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/gemini_context.py
rename to colossalai/legacy/zero/gemini/gemini_context.py
diff --git a/colossalai/zero/legacy/gemini/ophooks/__init__.py b/colossalai/legacy/zero/gemini/ophooks/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/ophooks/__init__.py
rename to colossalai/legacy/zero/gemini/ophooks/__init__.py
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
rename to colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
rename to colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py
diff --git a/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py
similarity index 98%
rename from colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py
rename to colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py
index f40d6ced1ee0..eebcf86e0e58 100644
--- a/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py
+++ b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py
@@ -5,9 +5,9 @@
import torch
+from colossalai.legacy.zero.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
-from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class TrainingPhase(Enum):
diff --git a/colossalai/zero/legacy/gemini/ophooks/utils.py b/colossalai/legacy/zero/gemini/ophooks/utils.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/ophooks/utils.py
rename to colossalai/legacy/zero/gemini/ophooks/utils.py
diff --git a/colossalai/zero/legacy/gemini/paramhooks/__init__.py b/colossalai/legacy/zero/gemini/paramhooks/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/paramhooks/__init__.py
rename to colossalai/legacy/zero/gemini/paramhooks/__init__.py
diff --git a/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py
rename to colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py
diff --git a/colossalai/zero/legacy/gemini/stateful_tensor.py b/colossalai/legacy/zero/gemini/stateful_tensor.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/stateful_tensor.py
rename to colossalai/legacy/zero/gemini/stateful_tensor.py
diff --git a/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/stateful_tensor_mgr.py
rename to colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
diff --git a/colossalai/zero/legacy/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py
similarity index 98%
rename from colossalai/zero/legacy/gemini/tensor_placement_policy.py
rename to colossalai/legacy/zero/gemini/tensor_placement_policy.py
index 165ae51fee60..275933ec2cfb 100644
--- a/colossalai/zero/legacy/gemini/tensor_placement_policy.py
+++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py
@@ -5,8 +5,8 @@
import torch
+from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.utils import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor
diff --git a/colossalai/zero/legacy/gemini/tensor_utils.py b/colossalai/legacy/zero/gemini/tensor_utils.py
similarity index 100%
rename from colossalai/zero/legacy/gemini/tensor_utils.py
rename to colossalai/legacy/zero/gemini/tensor_utils.py
diff --git a/colossalai/zero/legacy/init_ctx/__init__.py b/colossalai/legacy/zero/init_ctx/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/init_ctx/__init__.py
rename to colossalai/legacy/zero/init_ctx/__init__.py
diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/legacy/zero/init_ctx/init_context.py
similarity index 96%
rename from colossalai/zero/legacy/init_ctx/init_context.py
rename to colossalai/legacy/zero/init_ctx/init_context.py
index 84e2d2f4f8e1..4a7e46408583 100644
--- a/colossalai/zero/legacy/init_ctx/init_context.py
+++ b/colossalai/legacy/zero/init_ctx/init_context.py
@@ -8,15 +8,15 @@
import torch.distributed as dist
import torch.nn as nn
-from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.zero.shard_utils import BaseShardStrategy
+from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
+from colossalai.legacy.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+from colossalai.legacy.zero.sharded_param import ShardedParamV2
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from colossalai.zero.legacy.shard_utils import BaseShardStrategy
-from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
-from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
-from colossalai.zero.legacy.sharded_param import ShardedParamV2
@dataclass
diff --git a/colossalai/zero/legacy/shard_utils/__init__.py b/colossalai/legacy/zero/shard_utils/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/shard_utils/__init__.py
rename to colossalai/legacy/zero/shard_utils/__init__.py
diff --git a/colossalai/zero/legacy/shard_utils/base_shard_strategy.py b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py
similarity index 90%
rename from colossalai/zero/legacy/shard_utils/base_shard_strategy.py
rename to colossalai/legacy/zero/shard_utils/base_shard_strategy.py
index 7ca951091640..9fb80f57ae77 100644
--- a/colossalai/zero/legacy/shard_utils/base_shard_strategy.py
+++ b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py
@@ -3,7 +3,7 @@
import torch.distributed as dist
-from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC):
diff --git a/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
similarity index 97%
rename from colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py
rename to colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
index d663104831ce..1f7baad57816 100644
--- a/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -4,8 +4,8 @@
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
+from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
-from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
from .tensor_shard_strategy import TensorShardStrategy
diff --git a/colossalai/zero/legacy/shard_utils/commons.py b/colossalai/legacy/zero/shard_utils/commons.py
similarity index 100%
rename from colossalai/zero/legacy/shard_utils/commons.py
rename to colossalai/legacy/zero/shard_utils/commons.py
diff --git a/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
similarity index 90%
rename from colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py
rename to colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
index d1df4803b820..cc43907f6655 100644
--- a/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
@@ -3,11 +3,11 @@
import torch
import torch.distributed as dist
+from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.legacy.zero.shard_utils import BaseShardStrategy
+from colossalai.legacy.zero.shard_utils.commons import get_shard
+from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
-from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
-from colossalai.zero.legacy.shard_utils import BaseShardStrategy
-from colossalai.zero.legacy.shard_utils.commons import get_shard
-from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy):
diff --git a/colossalai/zero/legacy/sharded_model/__init__.py b/colossalai/legacy/zero/sharded_model/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/sharded_model/__init__.py
rename to colossalai/legacy/zero/sharded_model/__init__.py
diff --git a/colossalai/zero/legacy/sharded_model/_utils.py b/colossalai/legacy/zero/sharded_model/_utils.py
similarity index 97%
rename from colossalai/zero/legacy/sharded_model/_utils.py
rename to colossalai/legacy/zero/sharded_model/_utils.py
index f1d642cf3f13..b8a618ef5a0d 100644
--- a/colossalai/zero/legacy/sharded_model/_utils.py
+++ b/colossalai/legacy/zero/sharded_model/_utils.py
@@ -3,7 +3,7 @@
import torch
import torch.nn.functional as F
-from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
+from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float:
diff --git a/colossalai/zero/legacy/sharded_model/reduce_scatter.py b/colossalai/legacy/zero/sharded_model/reduce_scatter.py
similarity index 100%
rename from colossalai/zero/legacy/sharded_model/reduce_scatter.py
rename to colossalai/legacy/zero/sharded_model/reduce_scatter.py
diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py
similarity index 97%
rename from colossalai/zero/legacy/sharded_model/sharded_model_v2.py
rename to colossalai/legacy/zero/sharded_model/sharded_model_v2.py
index e7064277fb3c..91c21ccf9516 100644
--- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py
+++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py
@@ -11,20 +11,20 @@
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils.memory import colo_device_memory_capacity
+from colossalai.legacy.zero.gemini.ophooks import register_ophooks_recursively
+from colossalai.legacy.zero.gemini.paramhooks import BaseParamHookMgr
+from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
+from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.legacy.zero.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
+from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_cpu
+from colossalai.legacy.zero.shard_utils import BaseShardStrategy
+from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.logging import get_dist_logger
from colossalai.utils import disposable, get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
-from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively
-from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
-from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
-from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
-from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu
-from colossalai.zero.legacy.shard_utils import BaseShardStrategy
-from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer
+from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from ._utils import (
cast_float_arguments,
diff --git a/colossalai/zero/legacy/sharded_model/utils.py b/colossalai/legacy/zero/sharded_model/utils.py
similarity index 92%
rename from colossalai/zero/legacy/sharded_model/utils.py
rename to colossalai/legacy/zero/sharded_model/utils.py
index 08806e78ea3b..7a411669900b 100644
--- a/colossalai/zero/legacy/sharded_model/utils.py
+++ b/colossalai/legacy/zero/sharded_model/utils.py
@@ -2,7 +2,7 @@
import torch
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
+from colossalai.legacy.zero.sharded_model import ShardedModelV2
def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
diff --git a/colossalai/zero/legacy/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py
similarity index 94%
rename from colossalai/zero/legacy/sharded_model/zero_hook.py
rename to colossalai/legacy/zero/sharded_model/zero_hook.py
index 1815bee3a9e0..3fc373e5ca44 100644
--- a/colossalai/zero/legacy/sharded_model/zero_hook.py
+++ b/colossalai/legacy/zero/sharded_model/zero_hook.py
@@ -4,13 +4,13 @@
import torch.distributed as dist
from colossalai.legacy.registry import OPHOOKS
+from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
+from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
+from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
-from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
-from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
-from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.legacy.shard_utils import BaseShardStrategy
@OPHOOKS.register_module
diff --git a/colossalai/zero/legacy/sharded_optim/__init__.py b/colossalai/legacy/zero/sharded_optim/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/sharded_optim/__init__.py
rename to colossalai/legacy/zero/sharded_optim/__init__.py
diff --git a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py
similarity index 97%
rename from colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py
rename to colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py
index 41dd174cb65a..e21f1cea04df 100644
--- a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py
@@ -12,15 +12,15 @@
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.interface import OptimizerWrapper
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState
+from colossalai.legacy.zero.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.legacy.zero.sharded_model import ShardedModelV2
+from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
-from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
-from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32
class OptimState(Enum):
@@ -28,7 +28,7 @@ class OptimState(Enum):
UNSCALED = 2
-class ShardedOptimizerV2(ColossalaiOptimizer):
+class ShardedOptimizerV2(OptimizerWrapper):
"""A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO).
By default the ZeRO optimizer stage 3 offload Optimizer States on CPU.
diff --git a/colossalai/zero/legacy/sharded_param/__init__.py b/colossalai/legacy/zero/sharded_param/__init__.py
similarity index 100%
rename from colossalai/zero/legacy/sharded_param/__init__.py
rename to colossalai/legacy/zero/sharded_param/__init__.py
diff --git a/colossalai/zero/legacy/sharded_param/sharded_param.py b/colossalai/legacy/zero/sharded_param/sharded_param.py
similarity index 96%
rename from colossalai/zero/legacy/sharded_param/sharded_param.py
rename to colossalai/legacy/zero/sharded_param/sharded_param.py
index 4bcc4b62104a..454a722cf7e7 100644
--- a/colossalai/zero/legacy/sharded_param/sharded_param.py
+++ b/colossalai/legacy/zero/sharded_param/sharded_param.py
@@ -2,8 +2,8 @@
import torch
-from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
-from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage
+from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState
+from colossalai.legacy.zero.gemini.tensor_utils import colo_tensor_mem_usage
from .sharded_tensor import ShardedTensor
diff --git a/colossalai/zero/legacy/sharded_param/sharded_tensor.py b/colossalai/legacy/zero/sharded_param/sharded_tensor.py
similarity index 94%
rename from colossalai/zero/legacy/sharded_param/sharded_tensor.py
rename to colossalai/legacy/zero/sharded_param/sharded_tensor.py
index af60312600f2..43c7576b93b5 100644
--- a/colossalai/zero/legacy/sharded_param/sharded_tensor.py
+++ b/colossalai/legacy/zero/sharded_param/sharded_tensor.py
@@ -1,6 +1,6 @@
import torch
-from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
+from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor):
diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py
index f9abe4a2a2b6..fd05ddf1d50f 100644
--- a/colossalai/logging/logger.py
+++ b/colossalai/logging/logger.py
@@ -134,8 +134,6 @@ def info(self, message: str, ranks: List[int] = None) -> None:
Args:
message (str): The message to be logged.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
- The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
@@ -147,8 +145,6 @@ def warning(self, message: str, ranks: List[int] = None) -> None:
Args:
message (str): The message to be logged.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
- The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
@@ -160,8 +156,6 @@ def debug(self, message: str, ranks: List[int] = None) -> None:
Args:
message (str): The message to be logged.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
- The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
@@ -173,8 +167,6 @@ def error(self, message: str, ranks: List[int] = None) -> None:
Args:
message (str): The message to be logged.
- parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
- The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index edd986ef5e82..9aeab9f44a6d 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -1,2 +1,2 @@
-from .moe import *
+# from .moe import *
from .utils import *
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index 56b11f4d9e08..712d872bb921 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -6,10 +6,10 @@
import torch.distributed as dist
import torch.nn as nn
-from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.context import ParallelMode, seed
+from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator
from colossalai.utils import get_current_device
-from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator
class MoeExperts(nn.Module):
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index 03f55d91f3a8..9293d3208f11 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F
from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
from colossalai.nn.layer.moe._operation import (
COL_MOE_KERNEL_FLAG,
AllGather,
@@ -18,7 +19,6 @@
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
-from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator
@no_shard_zero_decrator(is_replicated=True)
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index ee2add48ab91..7c6fb099d272 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1 +1 @@
-from .loss_moe import MoeCrossEntropyLoss, MoeLoss
+# from .loss_moe import MoeCrossEntropyLoss, MoeLoss
diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py
index 06072648beba..7e310793f515 100644
--- a/colossalai/nn/optimizer/__init__.py
+++ b/colossalai/nn/optimizer/__init__.py
@@ -1,10 +1,9 @@
-from .colossalai_optimizer import ColossalaiOptimizer
+from .cpu_adam import CPUAdam
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
+from .hybrid_adam import HybridAdam
from .lamb import Lamb
from .lars import Lars
-from .cpu_adam import CPUAdam
-from .hybrid_adam import HybridAdam
-__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
+__all__ = ['FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py
deleted file mode 100644
index 34f5a9541975..000000000000
--- a/colossalai/nn/optimizer/colossalai_optimizer.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-import torch.nn as nn
-from torch import Tensor
-from torch.optim import Optimizer
-from colossalai.utils import clip_grad_norm_fp32
-
-
-class ColossalaiOptimizer(Optimizer):
-
- def __init__(self, optim: Optimizer):
- self.optim = optim
-
- @property
- def param_groups(self):
- return self.optim.param_groups
-
- @property
- def defaults(self):
- return self.optim.defaults
-
- def add_param_group(self, *args, **kwargs):
- return self.optim.add_param_group(*args, **kwargs)
-
- def step(self, *args, **kwargs):
- return self.optim.step(*args, **kwargs)
-
- def zero_grad(self, *args, **kwargs):
- self.optim.zero_grad(*args, **kwargs)
-
- def load_state_dict(self, *args, **kwargs):
- self.optim.load_state_dict(*args, **kwargs)
-
- def state_dict(self):
- return self.optim.state_dict()
-
- def backward(self, loss: Tensor):
- loss.backward()
-
- def backward_by_grad(self, tensor: Tensor, grad: Tensor):
- torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-
- def clip_grad_norm(self, model: nn.Module, max_norm: float):
- if max_norm > 0.0:
- clip_grad_norm_fp32(model.parameters(), max_norm)
diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py
index 0fcde9707646..e88a1f00a1b7 100644
--- a/colossalai/pipeline/__init__.py
+++ b/colossalai/pipeline/__init__.py
@@ -1,4 +1,11 @@
-from .pipelinable import PipelinableContext, PipelinableModel
-from .layer_spec import LayerSpec
+from .p2p import PipelineP2PCommunication
+from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
+from .stage_manager import PipelineStageManager
-__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
\ No newline at end of file
+__all__ = [
+ 'PipelineSchedule',
+ 'OneForwardOneBackwardSchedule',
+ 'InterleavedSchedule',
+ 'PipelineP2PCommunication',
+ 'PipelineStageManager',
+]
diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py
deleted file mode 100644
index 79e19f9eaf77..000000000000
--- a/colossalai/pipeline/middleware/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal
-
-__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
\ No newline at end of file
diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py
deleted file mode 100644
index 9d9e9d44f46c..000000000000
--- a/colossalai/pipeline/rpc/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
-from .utils import pytree_map
-
-__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
\ No newline at end of file
diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py
index 8b13413b1a31..07c0f5927060 100644
--- a/colossalai/pipeline/schedule/__init__.py
+++ b/colossalai/pipeline/schedule/__init__.py
@@ -1,7 +1,9 @@
from .base import PipelineSchedule
+from .interleaved_pp import InterleavedSchedule
from .one_f_one_b import OneForwardOneBackwardSchedule
__all__ = [
'PipelineSchedule',
'OneForwardOneBackwardSchedule',
+ 'InterleavedSchedule',
]
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index b2da64e6c33a..099376d931e8 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -1,18 +1,11 @@
-from . import distspec
from .colo_parameter import ColoParameter
from .colo_tensor import ColoTensor
from .comm_spec import CollectiveCommPattern, CommSpec
-from .compute_spec import ComputePattern, ComputeSpec
-from .dist_spec_mgr import DistSpecManager
-from .distspec import ReplicaSpec, ShardSpec
from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
-from .process_group import ProcessGroup
-from .tensor_spec import ColoTensorSpec
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
__all__ = [
- 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
- 'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
- 'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+ 'ColoTensor', 'convert_parameter', 'named_params_with_colotensor', 'ColoParameter', 'ColoParamOpHook',
+ 'ColoParamOpHookManager', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
'merge_same_dim_mesh_list'
]
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 6f9717d353e6..5226f688b43b 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,79 +1,32 @@
-from .activation_checkpoint import checkpoint
-from .checkpointing import load_checkpoint, save_checkpoint
from .common import (
_cast_float,
- clip_grad_norm_fp32,
conditional_context,
- copy_tensor_parallel_attributes,
- count_zeros_fp32,
disposable,
ensure_path_exists,
free_storage,
is_ddp_ignored,
- is_dp_rank_0,
- is_model_parallel_parameter,
- is_no_pp_or_last_stage,
- is_tp_rank_0,
- is_using_ddp,
- is_using_pp,
- is_using_sequence,
- multi_tensor_applier,
- param_is_not_tensor_parallel_duplicate,
- print_rank_0,
- switch_virtual_pipeline_parallel_rank,
- sync_model_param,
-)
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
-from .data_sampler import DataParallelSampler, get_dataloader
-from .memory import (
- colo_device_memory_capacity,
- colo_device_memory_used,
- colo_get_cpu_memory_capacity,
- colo_set_cpu_memory_capacity,
- colo_set_process_memory_fraction,
- report_memory_usage,
+ set_seed,
)
+from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize
+from .multi_tensor_apply import multi_tensor_applier
from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer
__all__ = [
- 'checkpoint',
- 'print_rank_0',
- 'sync_model_param',
- 'is_ddp_ignored',
- 'is_dp_rank_0',
- 'is_tp_rank_0',
- 'is_no_pp_or_last_stage',
- 'is_using_ddp',
- 'is_using_pp',
- 'is_using_sequence',
'conditional_context',
- 'is_model_parallel_parameter',
- 'clip_grad_norm_fp32',
- 'count_zeros_fp32',
- 'copy_tensor_parallel_attributes',
- 'param_is_not_tensor_parallel_duplicate',
'get_current_device',
'synchronize',
'empty_cache',
'set_to_cuda',
- 'report_memory_usage',
- 'colo_device_memory_capacity',
- 'colo_device_memory_used',
- 'colo_set_process_memory_fraction',
'Timer',
'MultiTimer',
'multi_tensor_applier',
- 'DataParallelSampler',
- 'get_dataloader',
- 'switch_virtual_pipeline_parallel_rank',
'TensorDetector',
- 'load_checkpoint',
- 'save_checkpoint',
'ensure_path_exists',
'disposable',
- 'colo_set_cpu_memory_capacity',
- 'colo_get_cpu_memory_capacity',
'_cast_float',
'free_storage',
+ 'set_seed',
+ 'is_ddp_ignored',
+ 'set_device',
]
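A short usage sketch of a few helpers that survive in `colossalai.utils` after this hunk (assuming a standard PyTorch install; `set_seed` is defined later in this patch):

```python
# Sketch exercising helpers kept in the trimmed colossalai.utils namespace.
import torch

from colossalai.utils import conditional_context, get_current_device, set_seed

set_seed(1234)                 # seeds python, numpy and torch RNGs
device = get_current_device()  # cuda:<idx> when available, else cpu

# conditional_context enters the given context manager only when enable=True.
with conditional_context(torch.no_grad(), enable=True):
    x = torch.ones(2, 2, device=device)
```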
diff --git a/colossalai/utils/checkpoint/__init__.py b/colossalai/utils/checkpoint/__init__.py
deleted file mode 100644
index 1795b4ce36f4..000000000000
--- a/colossalai/utils/checkpoint/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .module_checkpoint import save_checkpoint, load_checkpoint
-
-__all__ = ['save_checkpoint', 'load_checkpoint']
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 998901708239..8c769c5b13c0 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -3,44 +3,12 @@
import functools
import os
import random
-import socket
-from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable
+import numpy as np
import torch
-import torch.distributed as dist
-from torch import inf
-from torch.nn.parameter import Parameter
-
-from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.tensor import ColoParameter, ProcessGroup
-
-from .multi_tensor_apply import multi_tensor_applier
-
-try:
- from colossalai._C import fused_optim
-except:
- fused_optim = None
-
-
-def print_rank_0(msg: str, logger=None):
- """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
-
- Args:
- msg (str): A string message to output.
- logger (:class:`colossalai.logging.DistributedLogger`, optional):
- The logger to record the message, defaults to None.
- """
- if gpc.get_global_rank() == 0:
- if logger is None:
- print(msg, flush=True)
- else:
- logger.info(msg)
def ensure_path_exists(filename: str):
@@ -50,47 +18,6 @@ def ensure_path_exists(filename: str):
Path(dirpath).mkdir(parents=True, exist_ok=True)
-def sync_model_param(model, parallel_mode):
- r"""Make sure data parameters are consistent during Data Parallel Mode.
-
- Args:
- model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
- parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel mode to be checked.
-
- Note:
- The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
- in `parallel_mode `_
- """
- if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
- for param in model.parameters():
- ranks = gpc.get_ranks_in_group(parallel_mode)
- dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
-
-
-def is_dp_rank_0():
- return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)
-
-
-def is_tp_rank_0():
- return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)
-
-
-def is_no_pp_or_last_stage():
- return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
-
-
-def is_using_ddp():
- return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1
-
-
-def is_using_pp():
- return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
-
-
-def is_using_sequence():
- return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
-
-
@contextmanager
def conditional_context(context_manager, enable=True):
if enable:
@@ -100,365 +27,10 @@ def conditional_context(context_manager, enable=True):
yield
-class model_branch_context(object):
-
- def __enter__(self):
- self.env_status = env.save()
-
- def __exit__(self, *exc_info):
- env.load(**self.env_status)
-
-
-def is_model_parallel_parameter(p):
- return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
-
-
def is_ddp_ignored(p):
return getattr(p, '_ddp_to_ignore', False)
-def _calc_l2_norm(grads):
- # we should not
- global fused_optim
-
- if fused_optim is None:
- from colossalai.kernel.op_builder import FusedOptimBuilder
- fused_optim = FusedOptimBuilder().load()
-
- norm = 0.0
- if len(grads) > 0:
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- norm, _ = multi_tensor_applier(
- fused_optim.multi_tensor_l2norm,
- dummy_overflow_buf,
- [grads],
- False # no per-parameter norm
- )
- return norm
-
-
-def _calc_lp(grads, norm_type):
- norm = 0.0
- for grad in grads:
- grad_norm = torch.norm(grad, norm_type)
- norm += grad_norm**norm_type
- return norm
-
-
-def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
- if torch.is_tensor(norm) and norm.device.type != 'cuda':
- norm = norm.to(torch.cuda.current_device())
- return norm
-
-
-def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
- if isinstance(norm, float):
- norm = torch.Tensor([norm])
- if move_to_cuda:
- norm = norm.to(torch.cuda.current_device())
- return norm
-
-
-# ======== Gradient Clipping =========
-
-
-def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
- if len(params) == 0:
- return 0.0
- grads = [p.grad for p in params]
- use_cuda_kernel = grads[0].device.type == 'cuda'
- if norm_type == inf:
- local_lp = max([g.abs().max() for g in grads])
- elif norm_type == 2.0 and use_cuda_kernel:
- local_lp = _calc_l2_norm(grads)**norm_type
- else:
- local_lp = _calc_lp(grads, norm_type)
- if isinstance(local_lp, torch.Tensor):
- return local_lp.item()
- return local_lp
-
-
-def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
- if len(params) == 0:
- return 0.0
- buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
- for p in params:
- if p.is_replicate():
- buckets[None].append(p)
- else:
- buckets[p.get_process_group().tp_process_group()].append(p)
- total_lp = 0.0
- for group, bucket in buckets.items():
- local_lp = _compute_local_lp(bucket, norm_type)
- if group is not None:
- local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
- if norm_type == inf:
- dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
- else:
- dist.all_reduce(local_lp_tensor, group=group)
- local_lp = local_lp_tensor.item()
- if norm_type == inf:
- total_lp = max(total_lp, local_lp)
- else:
- total_lp += local_lp
- return total_lp
-
-
-def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
- if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
- total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
- if norm_type == inf:
- dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
- else:
- dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
- total_lp = total_lp_tensor.item()
- return total_lp
-
-
-def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
- grad_dtype = None
- cpu_grad_params: List[ColoParameter] = []
- cuda_grad_params: List[ColoParameter] = []
- for p in parameters:
- if p.grad is None:
- continue
- assert isinstance(p, ColoParameter)
- if grad_dtype is None:
- grad_dtype = p.grad.dtype
- assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
- if p.grad.device.type == 'cuda':
- cuda_grad_params.append(p)
- else:
- cpu_grad_params.append(p)
- norm_type = float(norm_type)
- cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
- cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
- if norm_type == inf:
- total_lp = max(cpu_lp, cuda_lp)
- else:
- total_lp = cpu_lp + cuda_lp
- return _compute_pp_grad_lp(total_lp, norm_type)
-
-
-def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
- norm_type = float(norm_type)
- total_norm = _compute_grad_lp(parameters, norm_type)
- if norm_type != inf:
- total_norm = total_norm**(1 / norm_type)
- return total_norm
-
-
-def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
- clip_coef = max_norm / (total_norm + 1e-6)
- if clip_coef < 1.0:
- cuda_grads: List[torch.Tensor] = []
- cpu_grads: List[torch.Tensor] = []
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
- for p in parameters:
- if p.grad is None:
- continue
- if p.grad.device.type == 'cuda':
- cuda_grads.append(p.grad.detach())
- else:
- cpu_grads.append(p.grad.detach())
- if len(cuda_grads) > 0:
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
- clip_coef)
- for g in cpu_grads:
- g.mul_(clip_coef)
-
-
-def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
- total_norm = compute_grad_norm(parameters, norm_type)
- _clip_grad_norm(parameters, max_norm, total_norm)
- return total_norm
-
-
-def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
- """Clips gradient norm of an iterable of parameters whose gradients are in fp32.
-
- This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
- added functionality to handle model parallel parameters.
-
- Note:
- the gradients are modified in place.
-
- Args:
- parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`):
- An iterable of Tensors or a single Tensor that will have gradients normalized.
- max_norm (Union[float, int]): Max norm of the gradients.
- norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm.
-
- Returns:
- float: Total norm of the parameters.
- """
-
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
-
- # Filter parameters based on:
- # - grad should not be none
- # - parameter should not be shared
- # - should not be a replica due to tensor model parallelism
- params: List[Parameter] = []
- has_zero_shared_param: bool = False
- for param in parameters:
- if param.grad is not None:
- # Make sure the grads are in fp32
- assert param.grad.dtype == torch.float, \
- f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
- if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
- has_zero_shared_param = True
- params.append(param)
-
- if len(params) == 0:
- enable_cuda_kernels = False
- else:
- enable_cuda_kernels = params[0].grad.device.type == 'cuda'
- # Norm parameters.
- max_norm = float(max_norm)
- norm_type = float(norm_type)
-
- # Parameters can be on CPU or CUDA
- # If parameters are on CPU, disable CUDA kernels
-
- # Calculate norm.
- if norm_type == inf:
- total_norm = max(p.grad.data.abs().max() for p in params)
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- # Take max across all model-parallel GPUs.
- if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
- dist.all_reduce(total_norm_cuda,
- op=dist.ReduceOp.MAX,
- group=gpc.get_group(ParallelMode.MODEL),
- async_op=False)
- if has_zero_shared_param:
- dist.all_reduce(total_norm_cuda,
- op=dist.ReduceOp.MAX,
- group=gpc.get_group(ParallelMode.DATA),
- async_op=False)
- total_norm = total_norm_cuda[0].item()
- else:
- tensor_parallel_grads = []
- no_tensor_parallel_grads = []
- zero_sharded_grads = []
- for p in params:
- if is_model_parallel_parameter(p):
- reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
- tensor_parallel_grads.append(p.grad.data / reductor)
- elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
- zero_sharded_grads.append(p.grad.data)
- else:
- no_tensor_parallel_grads.append(p.grad.data)
-
- if norm_type == 2.0 and enable_cuda_kernels:
- tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
- no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
- zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
- else:
- tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
- no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
- zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
- # If norm is type of float, then we convert them into torch.Tensor.
- tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
- no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels)
- zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels)
- # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
- if not enable_cuda_kernels:
- tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
- no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
- zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
-
- # Sum across all model-parallel GPUs.
- if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
- dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
- # Sum across all zero sharded GPUs
- if len(zero_sharded_grads) > 0:
- dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
- no_tensor_parallel_norm += zero_sharded_norm
- total_norm = tensor_parallel_norm + no_tensor_parallel_norm
- if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
- dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
- total_norm = total_norm**(1.0 / norm_type)
- if torch.is_tensor(total_norm):
- total_norm = total_norm.item()
-
- # Scale.
- clip_coeff = max_norm / (total_norm + 1.0e-6)
- if clip_coeff < 1.0:
- if enable_cuda_kernels:
- grads = [p.grad.detach() for p in params]
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
- else:
- for p in params:
- p.grad.detach().mul_(clip_coeff)
- return total_norm
-
-
-def count_zeros_fp32(parameters):
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
-
- # Filter parameters based on:
- # - grad should not be none
- # - parameter should not be shared
- # - should not be a replica due to tensor model parallelism
- total_num_zeros = 0.0
- for param in parameters:
- grad_not_none = param.grad is not None
- is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
- if grad_not_none and is_not_tp_duplicate:
- grad = param.grad.detach()
- num_zeros = grad.numel() - torch.count_nonzero(grad)
- total_num_zeros = num_zeros + total_num_zeros
-
- total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
-
- # Sum across all model-parallel GPUs.
- ops = []
- ops.append(
- dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
- if gpc.is_initialized(ParallelMode.PIPELINE):
- ops.append(
- dist.all_reduce(total_num_zeros,
- op=dist.ReduceOp.SUM,
- group=gpc.get_group(ParallelMode.PIPELINE),
- async_op=True))
-
- for req in ops:
- req.wait()
- total_num_zeros = total_num_zeros.item()
-
- return total_num_zeros
-
-
-def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
- for attr in TENSOR_PARALLEL_ATTRIBUTES:
- if hasattr(src_tensor, attr):
- val = getattr(src_tensor, attr)
- setattr(dst_tensor, attr, val)
-
-
-def param_is_not_tensor_parallel_duplicate(param):
- return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
- ParallelMode.TENSOR) == 0)
-
-
-@contextmanager
-def switch_virtual_pipeline_parallel_rank(rank):
- prev_rank = gpc.virtual_pipeline_parallel_rank
- try:
- gpc.set_virtual_pipeline_parallel_rank(rank)
- yield
- finally:
- gpc.set_virtual_pipeline_parallel_rank(prev_rank)
-
-
def disposable(func: Callable) -> Callable:
executed = False
@@ -489,3 +61,9 @@ def _cast_float(args, dtype: torch.dtype):
elif isinstance(args, dict):
args = {k: _cast_float(v, dtype) for k, v in args.items()}
return args
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
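A minimal reproducibility sketch for the new `set_seed` helper, assuming only stock PyTorch:

```python
# Same seed, same draw: set_seed re-seeds python's random, numpy and torch.
import torch

from colossalai.utils import set_seed

set_seed(42)
a = torch.randn(3)
set_seed(42)
b = torch.randn(3)
assert torch.equal(a, b)
# Note: torch.manual_seed (used inside set_seed) also seeds CUDA generators
# on recent PyTorch, but cuDNN determinism flags are left untouched here.
```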
diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py
index 60f3ccb60883..6b5d17cf04e7 100644
--- a/colossalai/utils/cuda.py
+++ b/colossalai/utils/cuda.py
@@ -1,7 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+from typing import Optional
+
import torch
+import torch.distributed as dist
def set_to_cuda(models):
@@ -23,7 +26,7 @@ def set_to_cuda(models):
def get_current_device() -> torch.device:
"""
Returns currently selected device (gpu/cpu).
- If cuda available, return gpu, otherwise return cpu.
+    If CUDA is available, return the current GPU device; otherwise return the CPU device.
"""
if torch.cuda.is_available():
return torch.device(f'cuda:{torch.cuda.current_device()}')
@@ -45,3 +48,9 @@ def empty_cache():
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
+
+
+def set_device(index: Optional[int] = None) -> None:
+ if index is None:
+ index = dist.get_rank() % torch.cuda.device_count()
+ torch.cuda.set_device(index)
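A hedged launch-time sketch for the new `set_device` helper; it assumes the process group is already initialized, since the default index is derived from `dist.get_rank()`:

```python
# Bind each process to one local GPU. Run under a launcher such as
# torchrun so that torch.distributed can be initialized first.
import torch.distributed as dist

from colossalai.utils import set_device

dist.init_process_group(backend='nccl')
set_device()  # defaults to rank % torch.cuda.device_count()
```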
diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py
index 86d04c11958b..6456dfb905b0 100644
--- a/colossalai/utils/moe.py
+++ b/colossalai/utils/moe.py
@@ -1,52 +1,54 @@
-import torch.nn as nn
-import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.context import ParallelMode
-from .common import is_using_ddp
-from typing import Dict, List
-
-
-def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
- """Returns a parameter dictionary, the key of which is the expert parallel
- size of every parameter. Since the parameters in data parallelism is replicated
- in each GPU, we set their ep_size to 1.
-
- Args:
- model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
- """
- epsize_param_dict = dict()
- for param in model.parameters():
- if not hasattr(param, 'moe_info'):
- ep_size = 1 # set ep_size to 1 for dp parameters
- else:
- ep_size = param.moe_info.ep_size
- if ep_size not in epsize_param_dict:
- epsize_param_dict[ep_size] = []
- epsize_param_dict[ep_size].append(param)
-
- return epsize_param_dict
-
-
-def sync_moe_model_param(model: nn.Module):
- """Make sure model parameters are consistent in MoE parallel context.
-
- Args:
- model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
- """
- if is_using_ddp():
-
- param_dict = get_moe_epsize_param_dict(model)
-
- # synchronize the parameters whose dp_group is the whole world
- if 1 in param_dict:
- src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
- for param in param_dict[1]:
- dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA))
-
- for ep_size in param_dict:
- # When ep_size = world_size, communication is not needed
- if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
- src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
- for param in param_dict[ep_size]:
- dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
+from typing import Dict, List
+
+import torch.distributed as dist
+import torch.nn as nn
+
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils import is_using_ddp
+
+
+def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
+ """Returns a parameter dictionary, the key of which is the expert parallel
+ size of every parameter. Since the parameters in data parallelism is replicated
+ in each GPU, we set their ep_size to 1.
+
+ Args:
+        model (:class:`torch.nn.Module`): A PyTorch `nn.Module` from which the dict is built.
+ """
+ epsize_param_dict = dict()
+ for param in model.parameters():
+ if not hasattr(param, 'moe_info'):
+ ep_size = 1 # set ep_size to 1 for dp parameters
+ else:
+ ep_size = param.moe_info.ep_size
+ if ep_size not in epsize_param_dict:
+ epsize_param_dict[ep_size] = []
+ epsize_param_dict[ep_size].append(param)
+
+ return epsize_param_dict
+
+
+def sync_moe_model_param(model: nn.Module):
+ """Make sure model parameters are consistent in MoE parallel context.
+
+ Args:
+        model (:class:`torch.nn.Module`): A PyTorch model whose parameters are checked for consistency.
+ """
+ if is_using_ddp():
+
+ param_dict = get_moe_epsize_param_dict(model)
+
+ # synchronize the parameters whose dp_group is the whole world
+ if 1 in param_dict:
+ src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
+ for param in param_dict[1]:
+ dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA))
+
+ for ep_size in param_dict:
+ # When ep_size = world_size, communication is not needed
+ if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+ src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
+ for param in param_dict[ep_size]:
+ dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
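A small call-site sketch; `build_moe_model` below is a hypothetical user-defined constructor, not part of this patch:

```python
# Broadcast freshly initialized parameters so every data-parallel replica
# starts from identical weights; this is a no-op unless DDP is in use.
from colossalai.utils.moe import sync_moe_model_param

model = build_moe_model()   # hypothetical constructor returning an MoE nn.Module
sync_moe_model_param(model)
```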
diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py
index dad852a34a71..549635af4332 100644
--- a/colossalai/zero/gemini/colo_init_context.py
+++ b/colossalai/zero/gemini/colo_init_context.py
@@ -3,7 +3,8 @@
import torch
from torch import nn
-from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
# find named_params includes replica
diff --git a/colossalai/zero/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py
index 02c9d5754ec9..e1fe904ebf1a 100644
--- a/colossalai/zero/gemini/memory_tracer/__init__.py
+++ b/colossalai/zero/gemini/memory_tracer/__init__.py
@@ -3,9 +3,8 @@
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
-from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__ = [
- 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
- 'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
+ 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats',
+ 'OrderedParamGenerator'
]
diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
index 83903bbf4023..b93ad2c44104 100644
--- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
+++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
@@ -1,7 +1,6 @@
from typing import Optional
from colossalai.utils import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import ChunkManager
from .memory_stats import MemStats
@@ -33,4 +32,5 @@ def record_model_data_volume(self) -> None:
@property
def cuda_margin_mem(self) -> float:
+ from colossalai.legacy.utils.memory import colo_device_memory_capacity
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
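The hunk above defers the legacy import to call time; a generic sketch of that pattern (function and module names are illustrative):

```python
# Deferred (function-local) import: importing this module no longer pulls
# in colossalai.legacy; the legacy dependency loads only when the function
# actually runs.
def device_capacity() -> int:
    from colossalai.legacy.utils.memory import colo_device_memory_capacity
    from colossalai.utils import get_current_device
    return colo_device_memory_capacity(get_current_device())
```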
diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py
index 4bb585677d5b..2a65d4b55409 100644
--- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py
+++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py
@@ -5,7 +5,7 @@
import torch
-from colossalai.utils import colo_device_memory_used, get_current_device
+from colossalai.utils import get_current_device
class MemoryMonitor:
@@ -110,6 +110,7 @@ def finish(self):
return max_usage
def _measure_usage(self):
+ from colossalai.legacy.utils import colo_device_memory_used
max_usage = 0
while self.keep_measuring:
max_usage = max(
diff --git a/colossalai/zero/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py
index 0694be48550a..abb3dcc74b27 100644
--- a/colossalai/zero/gemini/memory_tracer/memstats_collector.py
+++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py
@@ -70,7 +70,7 @@ def record_model_data_volume(self) -> None:
Sampling model data statistics.
"""
if self._start_flag and not self.use_outside_memstats:
- from colossalai.zero.legacy.gemini import StatefulTensor
+ from colossalai.legacy.zero.gemini import StatefulTensor
             # The following code works for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
index e5466965cc48..6656821fef74 100644
--- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,12 +1,12 @@
import torch.nn
-from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float
-from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
+from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.utils import _cast_float
from .memory_stats import MemStats
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index cd775da5e11f..a35529723a68 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -6,8 +6,8 @@
import torch
+from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.utils import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py
index 4205a9891534..ece92fe02e28 100644
--- a/colossalai/zero/low_level/_utils.py
+++ b/colossalai/zero/low_level/_utils.py
@@ -7,9 +7,6 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
-from colossalai.tensor import ColoParameter
-from colossalai.utils import is_model_parallel_parameter
-
def flatten(input_):
return _flatten_dense_tensors(input_)
diff --git a/docs/README.md b/docs/README.md
index f0cb50ffe217..a5ae2ce96a99 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -108,5 +108,5 @@ We support `autodoc` to extract the docstring and transform it into a Web elemen
You just need to add `{{ autodoc: }}` in your markdown as a single line. An example is given below and you can see the outcome in [this PR](https://github.com/hpcaitech/ColossalAI-Documentation/pull/175).
```markdown
-{{ autodoc:colossalai.amp.apex_amp.convert_to_apex_amp }}
+{{ autodoc:colossalai.legacy.amp.apex_amp.convert_to_apex_amp }}
```
diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md
index 384221596885..63434a526228 100644
--- a/docs/source/en/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/en/advanced_tutorials/add_your_parallel.md
@@ -31,7 +31,7 @@ global context for users to easily manage their process groups. If you wish to a
define a new class and set it in your configuration file. To define your own way of creating process groups, you can
follow the steps below to create a new distributed initialization.
-1. Add your parallel mode in `colossalai.context.parallel_mode.ParallelMode`.
+1. Add your parallel mode in `colossalai.legacy.context.parallel_mode.ParallelMode`.
```python
class ParallelMode(Enum):
GLOBAL = 'global'
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 36c94fb492cd..0218264cc258 100644
--- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -37,7 +37,7 @@ import torch.nn as nn
from colossalai import nn as col_nn
from colossalai.amp import AMP_TYPE
from colossalai.legacy.builder.pipeline import partition_uniform
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
diff --git a/docs/source/en/basics/command_line_tool.md b/docs/source/en/basics/command_line_tool.md
index 48b199cf78e9..4c278aaa0c6a 100644
--- a/docs/source/en/basics/command_line_tool.md
+++ b/docs/source/en/basics/command_line_tool.md
@@ -30,24 +30,4 @@ This command will inform you about the version compatibility and
To launch distributed jobs on single or multiple nodes, the command `colossalai run` can be used for process launching.
You may refer to [Launch Colossal-AI](./launch_colossalai.md) for more details.
-## Tensor Parallel Micro-Benchmarking
-
-As Colossal-AI provides an array of tensor parallelism methods, it is not intuitive to choose one for your hardware and
-model. Therefore, we provide a simple benchmarking to evaluate the performance of various tensor parallelisms on your system.
-This benchmarking is run on a simple MLP model where the input data is of the shape `(batch_size, seq_length, hidden_size)`.
-Based on the number of GPUs, the CLI will look for all possible tensor parallel configurations and display the benchmarking results.
-You can customize the benchmarking configurations by checking out `colossalai benchmark --help`.
-
-```shell
-# run on 4 GPUs
-colossalai benchmark --gpus 4
-
-# run on 8 GPUs
-colossalai benchmark --gpus 8
-```
-
-:::caution
-
-Only single-node benchmarking is supported currently.
-
-:::
+
diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
index c4b0f6557926..812b9c34e4da 100644
--- a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
@@ -24,7 +24,7 @@
并行通常由进程组来管理,参与相同并行算法的进程被置于同一进程组。对于不同的并行算法,需要创建不同的进程组。
Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管理进程组。如果你想添加新的进程组,你可以很容易地定义一个新的类并在你的配置文件中设置它。为了定义你自己的进程组创建方式,你可以按照下面的步骤来创建一个新的分布式初始化。
-1. 在 `colossalai.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。
+1. 在 `colossalai.legacy.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。
```python
class ParallelMode(Enum):
GLOBAL = 'global'
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 3f57f39f2838..a1d58e9fddc2 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -37,7 +37,7 @@ import torch.nn as nn
from colossalai import nn as col_nn
from colossalai.amp import AMP_TYPE
from colossalai.legacy.builder.pipeline import partition_uniform
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
diff --git a/docs/source/zh-Hans/basics/command_line_tool.md b/docs/source/zh-Hans/basics/command_line_tool.md
index 9b0275a6cedd..5c4c18989c17 100644
--- a/docs/source/zh-Hans/basics/command_line_tool.md
+++ b/docs/source/zh-Hans/basics/command_line_tool.md
@@ -26,22 +26,4 @@ Colossal-AI给用户提供了命令行工具,目前命令行工具可以用来
在分布式训练时,我们可以使用`colossalai run`来启动单节点或者多节点的多进程,详细的内容可以参考[启动 Colossal-AI](./launch_colossalai.md)。
-## 张量并行基准测试
-
-Colossal-AI提供了多种张量并行,想要充分理解这些方法需要一定的学习成本,对于新手来说很难靠经验选择一个并行方式。
-所以我们提供了一个简单的基准测试,能够让用户在自己的机器上测试不同张量并行的性能。这个基准测试跑一个并行的MLP模型,
-输入数据的维度为`(批大小,序列长度,隐藏层维度)`。通过指定GPU的数量,Colossal-AI会搜索所有可行的并行配置。用户可以通过查看`colossalai benchmark --help`来自定义相关的测试参数。
-
-```shell
-# 使用4个GPU
-colossalai benchmark --gpus 4
-
-# 使用8个GPU
-colossalai benchmark --gpus 8
-```
-
-:::caution
-
-目前仅支持单节点的基准测试。
-
-:::
+
diff --git a/examples/community/roberta/pretraining/pretrain_utils.py b/examples/community/roberta/pretraining/pretrain_utils.py
index cea6ac2c36e5..e6a393a57dda 100644
--- a/examples/community/roberta/pretraining/pretrain_utils.py
+++ b/examples/community/roberta/pretraining/pretrain_utils.py
@@ -16,7 +16,7 @@
get_linear_schedule_with_warmup,
)
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.optimizer import FusedAdam, HybridAdam
diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py
index 53fa9f489c10..fa6457cab328 100644
--- a/examples/community/roberta/pretraining/run_pretraining.py
+++ b/examples/community/roberta/pretraining/run_pretraining.py
@@ -17,7 +17,7 @@
import colossalai
from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
diff --git a/examples/community/roberta/pretraining/utils/exp_util.py b/examples/community/roberta/pretraining/utils/exp_util.py
index 4a2c9d8a47ad..1fcaa428b277 100644
--- a/examples/community/roberta/pretraining/utils/exp_util.py
+++ b/examples/community/roberta/pretraining/utils/exp_util.py
@@ -5,7 +5,7 @@
import psutil
import torch
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
def logging(s, log_path, print_=True, log_=True):
diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh
index 84345f589bb5..b0a96ec70075 100644
--- a/examples/images/dreambooth/test_ci.sh
+++ b/examples/images/dreambooth/test_ci.sh
@@ -1,24 +1,26 @@
#!/bin/bash
set -xe
-pip install -r requirements.txt
+echo "this test is slow"
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
-DIFFUSERS_OFFLINE=1
+# pip install -r requirements.txt
-# "torch_ddp" "torch_ddp_fp16" "low_level_zero"
-for plugin in "gemini"; do
- torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
- --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
- --instance_data_dir="/data/dreambooth/Teyvat/data" \
- --output_dir="./weight_output" \
- --instance_prompt="a picture of a dog" \
- --resolution=512 \
- --plugin=$plugin \
- --train_batch_size=1 \
- --learning_rate=5e-6 \
- --lr_scheduler="constant" \
- --lr_warmup_steps=0 \
- --test_run=True \
- --num_class_images=200
-done
+# HF_DATASETS_OFFLINE=1
+# TRANSFORMERS_OFFLINE=1
+# DIFFUSERS_OFFLINE=1
+
+# # "torch_ddp" "torch_ddp_fp16" "low_level_zero"
+# for plugin in "gemini"; do
+# torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+# --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+# --instance_data_dir="/data/dreambooth/Teyvat/data" \
+# --output_dir="./weight_output" \
+# --instance_prompt="a picture of a dog" \
+# --resolution=512 \
+# --plugin=$plugin \
+# --train_batch_size=1 \
+# --learning_rate=5e-6 \
+# --lr_scheduler="constant" \
+# --lr_warmup_steps=0 \
+# --test_run=True \
+# --num_class_images=200
+# done
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index f60704650b7e..9b2ed3b971ae 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -7,6 +7,7 @@
from typing import Optional
import torch
+import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
@@ -21,13 +22,9 @@
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
-from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers()
logger = get_dist_logger()
@@ -366,8 +363,8 @@ def main(args):
else:
colossalai.launch_from_torch(config={}, seed=args.seed)
- local_rank = gpc.get_local_rank(ParallelMode.DATA)
- world_size = gpc.get_world_size(ParallelMode.DATA)
+ local_rank = dist.get_rank()
+ world_size = dist.get_world_size()
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
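The new code queries plain `torch.distributed`; a hedged sketch of the pattern follows. Note that `dist.get_rank()` is the global rank, which coincides with the local rank only on single-node runs:

```python
# Rank and world-size lookup without the legacy global context.
import torch.distributed as dist

if dist.is_initialized():
    rank = dist.get_rank()              # global rank
    world_size = dist.get_world_size()
else:
    rank, world_size = 0, 1             # single-process fallback
```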
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
index c98950fd795d..654bce36ccb7 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -23,8 +23,8 @@
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
index e331fc8fcf10..84b02633e775 100644
--- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
+++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
@@ -7,8 +7,8 @@
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
-from colossalai.core import global_context as gpc
from colossalai.initialize import launch_from_torch
+from colossalai.legacy.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
BATCH_SIZE = 16
diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
index ad69888b8cc8..30d6aab4f12f 100644
--- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
+++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
@@ -3,7 +3,6 @@
from functools import partial
import torch
-from model_zoo import model_builder
from torch import nn
from tqdm import tqdm
@@ -14,11 +13,12 @@
split_with_split_nodes_pass,
)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
+from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.legacy.pipeline.rpc.utils import rpc_run
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
-from colossalai.pipeline.middleware.adaptor import get_fx_topology
-from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
-from colossalai.pipeline.rpc.utils import rpc_run
+from model_zoo import model_builder
def parse_args():
diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh
index 57ce6ab64c5b..5eaa4af4df78 100644
--- a/examples/language/gpt/gemini/run_gemini.sh
+++ b/examples/language/gpt/gemini/run_gemini.sh
@@ -9,11 +9,6 @@ export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10}
# export PYTHONPATH=$PWD:$PYTHONPATH
-if [ ${USE_SHARD_INIT} = "True" ]; then
- USE_SHARD_INIT="--shardinit"
-else
- USE_SHARD_INIT=""
-fi
mkdir -p gemini_logs
@@ -22,4 +17,4 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--batch_size=${BATCH_SIZE} \
--distplan=${DISTPLAN} \
--train_step=${TRAIN_STEP} \
-2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
+2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}.log
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 347251ca5631..f9d30fd15c7b 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -1,3 +1,4 @@
+import argparse
import os
from contextlib import nullcontext
from functools import partial
@@ -9,7 +10,6 @@
from commons.model_zoo import model_builder
from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
from packaging import version
-from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.booster import Booster
@@ -23,7 +23,7 @@
def parse_args():
- parser = colossalai.get_default_parser()
+ parser = argparse.ArgumentParser()
parser.add_argument(
"--distplan",
type=str,
diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh
index b9e4e43a8d35..db742220d97e 100644
--- a/examples/language/gpt/test_ci.sh
+++ b/examples/language/gpt/test_ci.sh
@@ -2,4 +2,4 @@ set -x
pip install -r requirements.txt
cd gemini && bash test_ci.sh
-cd ../hybridparallelism && bash run.sh
+# cd ../hybridparallelism && bash run.sh
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
index e521193a97da..a6c80394c50f 100644
--- a/examples/language/gpt/titans/model/embed.py
+++ b/examples/language/gpt/titans/model/embed.py
@@ -6,8 +6,8 @@
from torch.nn import functional as F
from torch.nn.parameter import Parameter
-from colossalai.context import ParallelMode, seed
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode, seed
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.base_layer import ParallelLayer
from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py
index 72297c540da1..746acbf7dccd 100644
--- a/examples/language/gpt/titans/model/gpt1d.py
+++ b/examples/language/gpt/titans/model/gpt1d.py
@@ -9,13 +9,13 @@
from colossalai import kernel
from colossalai import nn as col_nn
-from colossalai.core import global_context as gpc
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer import Linear1D_Col, Linear1D_Row
from colossalai.legacy.nn.layer.base_layer import ParallelLayer
from colossalai.legacy.nn.layer.utils import ACT2FN, divide
+from colossalai.legacy.utils.activation_checkpoint import checkpoint
from colossalai.utils import checkpoint
-from colossalai.utils.activation_checkpoint import checkpoint
__all__ = [
'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py
index 9b22d156bbcd..a9da246faf82 100644
--- a/examples/language/gpt/titans/model/pipeline_gpt1d.py
+++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py
@@ -7,11 +7,11 @@
from colossalai import kernel
from colossalai import nn as col_nn
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.legacy.pipeline.utils import partition_uniform
from colossalai.logging import get_dist_logger
-from colossalai.pipeline.utils import partition_uniform
from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
index b239b626c07f..3ed18b21fff5 100644
--- a/examples/language/gpt/titans/train_gpt.py
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -8,14 +8,14 @@
import colossalai
import colossalai.utils as utils
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.trainer import Trainer, hooks
+from colossalai.legacy.zero.init_ctx import ZeroInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
from colossalai.utils.timer import MultiTimer
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
def calc_local_model_size(model: torch.nn.Module):
diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
index a6a9ad0a312c..33aa5990f7c1 100644
--- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
@@ -4,8 +4,8 @@
import colossalai
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
-from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
diff --git a/examples/tutorial/auto_parallel/test_ci.sh b/examples/tutorial/auto_parallel/test_ci.sh
index bf6275b673ff..b27e36217117 100644
--- a/examples/tutorial/auto_parallel/test_ci.sh
+++ b/examples/tutorial/auto_parallel/test_ci.sh
@@ -1,6 +1,8 @@
#!/bin/bash
set -euxo pipefail
-pip install -r requirements.txt
-conda install -c conda-forge coin-or-cbc
-colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
+echo "this test is outdated"
+
+# pip install -r requirements.txt
+# conda install -c conda-forge coin-or-cbc
+# colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py
index fe9abf2f1955..287f62aa7a90 100644
--- a/examples/tutorial/hybrid_parallel/config.py
+++ b/examples/tutorial/hybrid_parallel/config.py
@@ -1,4 +1,4 @@
-from colossalai.amp import AMP_TYPE
+from colossalai.legacy.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is as per GPU
diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py
index 12cdec902400..21a568168e33 100644
--- a/examples/tutorial/hybrid_parallel/train.py
+++ b/examples/tutorial/hybrid_parallel/train.py
@@ -5,12 +5,12 @@
from tqdm import tqdm
import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn import CrossEntropyLoss
+from colossalai.legacy.pipeline.pipelinable import PipelinableContext
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.utils import is_using_pp
diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py
index 2efa0ffd0556..c6d9f94505f1 100644
--- a/examples/tutorial/large_batch_optimizer/config.py
+++ b/examples/tutorial/large_batch_optimizer/config.py
@@ -1,4 +1,4 @@
-from colossalai.amp import AMP_TYPE
+from colossalai.legacy.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is as per GPU
diff --git a/examples/tutorial/large_batch_optimizer/test_ci.sh b/examples/tutorial/large_batch_optimizer/test_ci.sh
index 89f426c542b1..f4393938220d 100644
--- a/examples/tutorial/large_batch_optimizer/test_ci.sh
+++ b/examples/tutorial/large_batch_optimizer/test_ci.sh
@@ -1,8 +1,9 @@
#!/bin/bash
set -euxo pipefail
+echo "this test is outdated"
-pip install -r requirements.txt
+# pip install -r requirements.txt
# run test
-colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars
-colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb
+# colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars
+# colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb
diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py
index 35e54582f494..6ebd8d68083d 100644
--- a/examples/tutorial/large_batch_optimizer/train.py
+++ b/examples/tutorial/large_batch_optimizer/train.py
@@ -4,7 +4,7 @@
from tqdm import tqdm
import colossalai
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import Lamb, Lars
diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py
index 7c2c152450c5..8fbed6e83d52 100644
--- a/examples/tutorial/opt/opt/colossalai_zero.py
+++ b/examples/tutorial/opt/opt/colossalai_zero.py
@@ -2,7 +2,7 @@
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
- from colossalai.zero.legacy import TensorShardStrategy
+ from colossalai.legacy.zero import TensorShardStrategy
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
tensor_placement_policy="auto",
diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py
index 95f0abf1d8c9..dfcd3b382d3c 100644
--- a/examples/tutorial/opt/opt/context.py
+++ b/examples/tutorial/opt/opt/context.py
@@ -1,7 +1,7 @@
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
class barrier_context():
diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py
index 91380e243fb8..8cbf3d2a2850 100755
--- a/examples/tutorial/opt/opt/run_clm.py
+++ b/examples/tutorial/opt/opt/run_clm.py
@@ -51,12 +51,13 @@
from transformers.utils.versions import require_version
import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.legacy.utils import get_dataloader
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
+from colossalai.utils import get_current_device
from colossalai.zero import GeminiOptimizer
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
diff --git a/examples/tutorial/opt/opt/test_ci.sh b/examples/tutorial/opt/opt/test_ci.sh
index 431b37c12004..9cbc49c7b001 100755
--- a/examples/tutorial/opt/opt/test_ci.sh
+++ b/examples/tutorial/opt/opt/test_ci.sh
@@ -1,21 +1,21 @@
#!/bin/bash
set -xue
+echo "this test is outdated"
+# pip install -r requirements.txt
-pip install -r requirements.txt
+# BS=4
+# MEMCAP=0
+# GPUNUM=4
+# MODLE="facebook/opt-125m"
-BS=4
-MEMCAP=0
-GPUNUM=4
-MODLE="facebook/opt-125m"
-
-torchrun \
- --nproc_per_node ${GPUNUM} \
- --master_port 19198 \
- run_clm.py \
- -s \
- --output_dir $PWD \
- --mem_cap ${MEMCAP} \
- --model_name_or_path ${MODLE} \
- --per_device_train_batch_size ${BS} \
- --num_train_epochs 1
+# torchrun \
+# --nproc_per_node ${GPUNUM} \
+# --master_port 19198 \
+# run_clm.py \
+# -s \
+# --output_dir $PWD \
+# --mem_cap ${MEMCAP} \
+# --model_name_or_path ${MODLE} \
+# --per_device_train_batch_size ${BS} \
+# --num_train_epochs 1
diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py
index 6edf9cc2c7e5..887de7164e12 100644
--- a/examples/tutorial/sequence_parallel/config.py
+++ b/examples/tutorial/sequence_parallel/config.py
@@ -1,4 +1,4 @@
-from colossalai.amp import AMP_TYPE
+from colossalai.legacy.amp import AMP_TYPE
# hyper-parameters
TRAIN_ITERS = 10
diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py
index 1ef2d999389f..6fdf07ba5b69 100644
--- a/examples/tutorial/sequence_parallel/data/__init__.py
+++ b/examples/tutorial/sequence_parallel/data/__init__.py
@@ -1,10 +1,12 @@
-from colossalai.context.parallel_context import ParallelContext
-from colossalai.core import global_context as gpc
+import torch
+
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.context.parallel_context import ParallelContext
+from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
-from colossalai.context import ParallelMode
-from .datasets.data_samplers import build_pretraining_data_loader
+
from .datasets.builder import build_train_valid_test_datasets
-import torch
+from .datasets.data_samplers import build_pretraining_data_loader
def cyclic_iter(iter):
@@ -18,8 +20,7 @@ def build_train_valid_test_data_iterators(train_iters,
eval_interval,
eval_iters,
dataloader_type='single',
- **kwargs
- ):
+ **kwargs):
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
logger = get_dist_logger()
@@ -42,9 +43,7 @@ def build_train_valid_test_data_iterators(train_iters,
train_samples = train_iters * global_batch_size
eval_iters_ = (train_iters // eval_interval + 1) * eval_iters
test_iters = eval_iters
- train_val_test_num_samples = [train_samples,
- eval_iters_ * global_batch_size,
- test_iters * global_batch_size]
+ train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size]
logger.info(' > datasets target sizes (minimum size):')
logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0])
logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0])
@@ -56,19 +55,20 @@ def build_train_valid_test_data_iterators(train_iters,
# Build dataloaders.
dp_size = gpc.get_world_size(ParallelMode.DATA)
- train_dataloader = build_pretraining_data_loader(
- train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size)
- valid_dataloader = build_pretraining_data_loader(
- valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size)
- test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size)
+ train_dataloader = build_pretraining_data_loader(train_ds,
+ consumed_samples=0,
+ micro_batch_size=global_batch_size // dp_size)
+ valid_dataloader = build_pretraining_data_loader(valid_ds,
+ consumed_samples=0,
+ micro_batch_size=global_batch_size // dp_size)
+ test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size)
# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and train_iters > 0
do_valid = valid_dataloader is not None and eval_iters > 0
do_test = test_dataloader is not None and eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
- flags = torch.cuda.LongTensor(
- [int(do_train), int(do_valid), int(do_test)])
+ flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py
index d092db3e7dd8..b65ca1e64f3c 100644
--- a/examples/tutorial/sequence_parallel/data/bert_helper.py
+++ b/examples/tutorial/sequence_parallel/data/bert_helper.py
@@ -1,7 +1,8 @@
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
import torch
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+
_MAX_DATA_DIM = 5
@@ -22,7 +23,8 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
- torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
+ torch.distributed.broadcast(sizes_cuda,
+ gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
# Move back to cpu and unpack.
@@ -60,19 +62,15 @@ def broadcast_data(keys, data, datatype):
"""
# Build (key, size) and (key, number of elements) dictionaries along
# with the total number of elements on all ranks.
- key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
- data)
+ key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
# Pack on rank zero.
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# Check that all keys have the same data type.
# Flatten the data associated with the keys
- flatten_data = torch.cat(
- [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
+ flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
else:
- flatten_data = torch.empty(total_numel,
- device=torch.cuda.current_device(),
- dtype=datatype)
+ flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
# Broadcast
torch.distributed.broadcast(flatten_data,
@@ -139,7 +137,7 @@ def get_batch_for_sequence_parallel(data_iterator):
seq_length = data_b['text'].size(1)
sub_seq_length = seq_length // local_world_size
sub_seq_start = local_rank * sub_seq_length
- sub_seq_end = (local_rank+1) * sub_seq_length
+ sub_seq_end = (local_rank + 1) * sub_seq_length
#
# # Unpack.
tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long()
@@ -156,10 +154,9 @@ class SequenceParallelDataIterator:
def __init__(self, data_iter):
self.data_iter = data_iter
-
def __iter__(self):
return self.data_iter
def __next__(self):
- return get_batch_for_sequence_parallel(self.data_iter)
\ No newline at end of file
+ return get_batch_for_sequence_parallel(self.data_iter)
diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py
index d6388bd9f8e4..70c1269122dc 100644
--- a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py
+++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py
@@ -21,8 +21,8 @@
import torch
from torch.utils.data import Dataset
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
from ..tokenizer import get_tokenizer
diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py
index cf547ad97558..b9c197c95ae3 100644
--- a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py
+++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py
@@ -14,10 +14,12 @@
# limitations under the License.
"""Dataloaders."""
-import torch
import random
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
+
+import torch
+
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py
index ee3c923e8e76..ba832b5cdce9 100644
--- a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py
+++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py
@@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Megatron tokenizers."""
-from abc import ABC
-from abc import abstractmethod
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
+from abc import ABC, abstractmethod
+
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from .bert_tokenization import FullTokenizer as FullBertTokenizer
@@ -26,18 +25,13 @@
def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
"""Initialize tokenizer."""
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
- print('> building {} tokenizer ...'.format(tokenizer_type),
- flush=True)
+ print('> building {} tokenizer ...'.format(tokenizer_type), flush=True)
# Select and instantiate the tokenizer.
if tokenizer_type == 'BertWordPieceLowerCase':
- tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
- lower_case=True,
- vocab_extra_ids=vocab_extra_ids)
+ tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids)
elif tokenizer_type == 'BertWordPieceCase':
- tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
- lower_case=False,
- vocab_extra_ids=vocab_extra_ids)
+ tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(tokenizer_type))
@@ -62,8 +56,8 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128):
after += 1
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
- '(new size: {})'.format(
- orig_vocab_size, after - orig_vocab_size, after), flush=True)
+ '(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after),
+ flush=True)
return after
@@ -142,8 +136,7 @@ def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
- SPECIAL_TOKENS = {'eos_token': '[EOS]',
- 'bos_token': '[BOS]'}
+ SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
@@ -155,8 +148,7 @@ def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
# (dsachan) Add additional special tokens
# These can be used as sentinel tokens in T5 model inputs
additional_special_tokens = []
-        additional_special_tokens.extend(
-            ["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
+        additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
self.add_additional_special_tokens(additional_special_tokens)
def add_token(self, token):
diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py
index e87a778cf5d5..b3f2487a438b 100644
--- a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py
+++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py
@@ -1,37 +1,29 @@
import torch
+import torch.distributed as dist
import torch.nn as nn
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
-from colossalai.logging import get_dist_logger
import torch.nn.functional as F
-import torch.distributed as dist
+
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+
from .cross_entropy import vocab_cross_entropy
class BertLoss(nn.Module):
- def forward(self,
- lm_loss,
- sop_logits,
- loss_mask,
- sentence_order):
+ def forward(self, lm_loss, sop_logits, loss_mask, sentence_order):
lm_loss_ = lm_loss.float()
loss_mask = loss_mask.float()
loss_mask_sum = loss_mask.sum()
- lm_loss = torch.sum(
- lm_loss_.view(-1) * loss_mask.reshape(-1))
+ lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1))
lm_loss /= loss_mask_sum
- torch.distributed.all_reduce(
- lm_loss,
- group=gpc.get_group(ParallelMode.SEQUENCE)
- )
+ torch.distributed.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE))
if sop_logits is not None:
- sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
- sentence_order.view(-1),
- ignore_index=-1)
+ sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
else:
diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py
index 54553c29a61f..ed15c6ea8054 100644
--- a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py
+++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py
@@ -1,7 +1,8 @@
-from colossalai.context.parallel_mode import ParallelMode
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
+from colossalai.legacy.context.parallel_mode import ParallelMode
+
class _VocabCrossEntropy(torch.autograd.Function):
@@ -24,8 +25,7 @@ def forward(ctx, vocab_parallel_logits, target):
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1))
masked_target_1d = masked_target.view(-1)
- arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
- device=logits_2d.device)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
@@ -58,10 +58,8 @@ def backward(ctx, grad_output):
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
- device=grad_2d.device)
- grad_2d[arange_1d, masked_target_1d] -= (
- 1.0 - target_mask.view(-1).float())
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+ grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py
index b8adb501f95e..4ba64bbe2b9f 100644
--- a/examples/tutorial/sequence_parallel/model/bert.py
+++ b/examples/tutorial/sequence_parallel/model/bert.py
@@ -3,13 +3,13 @@
import torch
import torch.nn as nn
-from colossalai.context import ParallelMode
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.kernel import LayerNorm
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.legacy.pipeline.utils import partition_uniform
from colossalai.logging import get_dist_logger
-from colossalai.pipeline.utils import partition_uniform
from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding
from .layers.init_method import init_normal, output_init_normal
diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py
index ea336b9d131e..9e25157e1b40 100644
--- a/examples/tutorial/sequence_parallel/model/layers/head.py
+++ b/examples/tutorial/sequence_parallel/model/layers/head.py
@@ -1,15 +1,17 @@
-import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
-from .pooler import Pooler
-from .linear import Linear
-from .embedding import VocabEmbedding
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
-from colossalai.kernel import LayerNorm
from loss_func.cross_entropy import vocab_cross_entropy
+import colossalai
+from colossalai.kernel import LayerNorm
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+
+from .embedding import VocabEmbedding
+from .linear import Linear
+from .pooler import Pooler
+
class BertLMHead(nn.Module):
"""Masked LM head for Bert
@@ -19,10 +21,11 @@ class BertLMHead(nn.Module):
layernorm_epsilon: tolerance for layer norm divisions
"""
- def __init__(self,
- vocab_size,
- hidden_size,
- ):
+ def __init__(
+ self,
+ vocab_size,
+ hidden_size,
+ ):
super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(vocab_size))
diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py
index 53a326ddacf1..dd66bfe13585 100644
--- a/examples/tutorial/sequence_parallel/model/layers/preprocess.py
+++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py
@@ -1,7 +1,8 @@
-from colossalai.context.parallel_mode import ParallelMode
import torch
import torch.nn as nn
-from colossalai.core import global_context as gpc
+
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
class PreProcessor(nn.Module):
@@ -14,8 +15,8 @@ def bert_position_ids(self, token_ids):
# Create position ids
seq_length = token_ids.size(1)
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
- position_ids = torch.arange(seq_length*local_rank,
- seq_length * (local_rank+1),
+ position_ids = torch.arange(seq_length * local_rank,
+ seq_length * (local_rank + 1),
dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
diff --git a/examples/tutorial/sequence_parallel/test_ci.sh b/examples/tutorial/sequence_parallel/test_ci.sh
index 7bc20de3b6e4..1cd646526d99 100644
--- a/examples/tutorial/sequence_parallel/test_ci.sh
+++ b/examples/tutorial/sequence_parallel/test_ci.sh
@@ -1,7 +1,8 @@
#!/bin/bash
set -euxo pipefail
-pip install -r requirements.txt
+echo "this test is outdated"
+# pip install -r requirements.txt
# run test
-colossalai run --nproc_per_node 4 train.py
+# colossalai run --nproc_per_node 4 train.py
diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py
index 86c4edeb5550..b8b89cda5525 100644
--- a/examples/tutorial/sequence_parallel/train.py
+++ b/examples/tutorial/sequence_parallel/train.py
@@ -8,14 +8,15 @@
from model.bert import BertForPretrain, build_pipeline_bert
import colossalai
-from colossalai.amp import AMP_TYPE
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.kernel import LayerNorm
+from colossalai.legacy.amp import AMP_TYPE
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine.schedule import PipelineSchedule
+from colossalai.legacy.utils import is_using_pp
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import FusedAdam
-from colossalai.utils import MultiTimer, is_using_pp
+from colossalai.utils import MultiTimer
def process_batch_data(batch_data):
diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py
index 193832ebc12d..df01e4c4847e 100644
--- a/tests/components_to_test/resnet.py
+++ b/tests/components_to_test/resnet.py
@@ -1,11 +1,14 @@
-from torchvision.models import resnet18
-from .registry import non_distributed_component_funcs
-from pathlib import Path
import os
+from pathlib import Path
+
import torch
-from torchvision.transforms import transforms
from torchvision.datasets import CIFAR10
-from colossalai.utils import get_dataloader
+from torchvision.models import resnet18
+from torchvision.transforms import transforms
+
+from colossalai.legacy.utils import get_dataloader
+
+from .registry import non_distributed_component_funcs
def get_cifar10_dataloader(train):
diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
index f184f64b35d0..b65e6d0d8863 100644
--- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
+++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
@@ -6,12 +6,12 @@
import torchvision.models as tm
import colossalai
-from colossalai.core import global_context as gpc
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
# from colossalai.fx.passes.algorithms import solver_rotor
# from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.legacy.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
index db268b91d0a0..babdddfada18 100644
--- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
+++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -8,12 +8,12 @@
from torch.fx import GraphModule
import colossalai
-from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.legacy.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
index 4e3c26c1ba9c..715f62358e2d 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
@@ -13,10 +13,9 @@
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.utils import get_current_device
-from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
class MLP(torch.nn.Module):
@@ -70,14 +69,12 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
print(strategy)
print('=' * msg_length)
- dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)
gemini_config = dict(strict_ddp_mode=False,
device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
search_range_m=128)
- post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
index 15610e2b50dc..593658fd1368 100644
--- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
@@ -6,9 +6,9 @@
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.autochunk.utils import flat_list
-from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.legacy.core import global_context as gpc
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
index b6a792f5652c..264331a5fef0 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
@@ -5,9 +5,9 @@
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
-from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.legacy.core import global_context as gpc
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
index 3202318fb6d1..65d1e9c4d090 100644
--- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
+++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
@@ -5,9 +5,9 @@
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
-from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.legacy.core import global_context as gpc
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py
index 13b7119424e4..2304203d1e04 100644
--- a/tests/test_cluster/test_process_group_mesh.py
+++ b/tests/test_cluster/test_process_group_mesh.py
@@ -7,8 +7,8 @@
def check_process_group_mesh_with_gpc():
- from colossalai.context import ParallelMode
- from colossalai.core import global_context as gpc
+ from colossalai.legacy.context import ParallelMode
+ from colossalai.legacy.core import global_context as gpc
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, 2, 2)
@@ -138,7 +138,7 @@ def run_dist(rank, world_size, port):
port=port,
host='localhost')
# TODO(ver217): this function should be removed when gpc is removed
- check_process_group_mesh_with_gpc()
+ # check_process_group_mesh_with_gpc()
check_process_group_mesh_with_cases()
diff --git a/tests/test_context/configs/parallel_2d_init.py b/tests/test_context/configs/parallel_2d_init.py
deleted file mode 100644
index 6af884450ad0..000000000000
--- a/tests/test_context/configs/parallel_2d_init.py
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-parallel = dict(
- pipeline=dict(size=2),
- tensor=dict(
- size=4,
- mode='2d'
- )
-)
diff --git a/tests/test_context/configs/parallel_2p5d_init.py b/tests/test_context/configs/parallel_2p5d_init.py
deleted file mode 100644
index c2d896d383e2..000000000000
--- a/tests/test_context/configs/parallel_2p5d_init.py
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-parallel = dict(
- pipeline=dict(size=2),
- tensor=dict(
- size=8,
- depth=2,
- mode='2.5d'
- )
-)
diff --git a/tests/test_context/configs/parallel_3d_init.py b/tests/test_context/configs/parallel_3d_init.py
deleted file mode 100644
index 0ec724f8bb4f..000000000000
--- a/tests/test_context/configs/parallel_3d_init.py
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-parallel = dict(
- pipeline=dict(size=2),
- tensor=dict(
- size=8,
- mode='3d'
- )
-)
diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py
index 7c6339eff67e..c18bf56752fb 100644
--- a/tests/test_device/test_init_logical_pg.py
+++ b/tests/test_device/test_init_logical_pg.py
@@ -3,7 +3,6 @@
import torch.distributed as dist
from torch.distributed import ReduceOp
-from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -13,7 +12,7 @@ def check_layer(rank, world_size, port):
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
- assert rank == gpc.get_global_rank()
+ assert rank == dist.get_rank()
tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
mesh_shape = (2, 2)
@@ -27,8 +26,6 @@ def check_layer(rank, world_size, port):
dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
assert tensor.equal(tensor_to_check)
- gpc.destroy()
-
@pytest.mark.dist
@rerun_if_address_is_in_use()
diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
index bcac2ec426d9..6a12f5bc848e 100644
--- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
@@ -4,9 +4,9 @@
from torch.utils.checkpoint import checkpoint
import colossalai
-from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.legacy.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
try:
diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
index 5b327807a57b..ebcfb4d7b633 100644
--- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
@@ -2,9 +2,9 @@
import torch
import colossalai
-from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.legacy.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
try:
diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py
index c217b96586fe..dac59c23655e 100644
--- a/tests/test_fx/test_codegen/test_offload_codegen.py
+++ b/tests/test_fx/test_codegen/test_offload_codegen.py
@@ -5,9 +5,9 @@
from torch.fx import GraphModule
import colossalai
-from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.legacy.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
try:
diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py
index 1044be7db1f4..29135b45f997 100644
--- a/tests/test_fx/test_parallel_1d.py
+++ b/tests/test_fx/test_parallel_1d.py
@@ -5,9 +5,9 @@
import torch
from torch.fx import symbolic_trace
-from colossalai.core import global_context as gpc
from colossalai.fx.passes import column_shard_linear_pass
from colossalai.initialize import launch
+from colossalai.legacy.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py
index 55dd65201acd..db6cadfc544c 100644
--- a/tests/test_fx/test_pipeline/test_topo/topo_utils.py
+++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py
@@ -1,18 +1,22 @@
+import random
+
+import numpy as np
import torch
from torch.fx import GraphModule
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
+
from colossalai.fx import ColoTracer
-from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
-from colossalai.pipeline.middleware.adaptor import get_fx_topology
-import random
-import numpy as np
+from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
+from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
+
class MLP(torch.nn.Module):
+
def __init__(self, config={}):
super().__init__()
dim = config['dim']
@@ -27,6 +31,7 @@ def forward(self, x):
x = layer(x)
return x
+
def split_model_and_get_DAG(model, data_gen):
model.eval()
@@ -46,7 +51,7 @@ def split_model_and_get_DAG(model, data_gen):
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
top_module, split_submodules = split_with_split_nodes_pass(annotated_model)
-
+
topo = get_fx_topology(top_module)
for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule):
@@ -54,6 +59,7 @@ def split_model_and_get_DAG(model, data_gen):
return top_module, split_submodules[0]._topo
+
def check_input(top_module, input_partition: Partition):
partition_output = input_partition.get_output_vals()
arg_pos = 0
@@ -63,13 +69,14 @@ def check_input(top_module, input_partition: Partition):
to_partition_and_offset = cur_checkee.get()
assert len(to_partition_and_offset) == len(node.users.keys())
arg_pos += 1
-
+
assert arg_pos == len(partition_output)
-
+
+
def check_submod(top_module, part_id, mid_partition: Partition):
partition_input = mid_partition.get_input_vals()
partition_output = mid_partition.get_output_vals()
-
+
cnt = 1
cur_node = None
for node in top_module.graph.nodes:
@@ -78,15 +85,15 @@ def check_submod(top_module, part_id, mid_partition: Partition):
if cnt == part_id:
cur_node = node
break
-
+
assert len(partition_input) == len(cur_node.args)
assert len(partition_output) == len(cur_node.users)
-def check_topo(top_module, topo: Topo):
+
+def check_topo(top_module, topo: Topo):
input_partition = topo.get_input_partition()
mid_partitions = topo.get_mid_partitions()
-
+
check_input(top_module, input_partition)
for part_id, submod in mid_partitions.items():
check_submod(top_module, part_id, submod)
-
\ No newline at end of file
diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py
similarity index 94%
rename from tests/test_amp/test_naive_fp16.py
rename to tests/test_legacy/test_amp/test_naive_fp16.py
index 6ce4c7f49725..54bf6498549c 100644
--- a/tests/test_amp/test_naive_fp16.py
+++ b/tests/test_legacy/test_amp/test_naive_fp16.py
@@ -4,7 +4,7 @@
import torch
import colossalai
-from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
+from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -78,7 +78,7 @@ def run_naive_amp():
def run_dist(rank, world_size, port):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+ colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
run_naive_amp()
diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py
similarity index 95%
rename from tests/test_amp/test_torch_fp16.py
rename to tests/test_legacy/test_amp/test_torch_fp16.py
index 6451aa6264a3..89810b5d0351 100644
--- a/tests/test_amp/test_torch_fp16.py
+++ b/tests/test_legacy/test_amp/test_torch_fp16.py
@@ -4,7 +4,7 @@
import torch
import colossalai
-from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
+from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -78,7 +78,7 @@ def run_torch_amp():
def run_dist(rank, world_size, port):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+ colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
run_torch_amp()
diff --git a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
index c5fb049fe93f..4851b3e36bbc 100644
--- a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
+++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
@@ -1,10 +1,10 @@
import pytest
import torch
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
from colossalai.legacy.communication.p2p_v2 import _recv_object, _send_object
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py
index 3251d8d46f0b..fccfcd973000 100644
--- a/tests/test_legacy/test_comm/test_comm.py
+++ b/tests/test_legacy/test_comm/test_comm.py
@@ -2,10 +2,10 @@
import torch
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
diff --git a/tests/test_legacy/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py
index f50982ee1c2d..a1322e6f28db 100644
--- a/tests/test_legacy/test_comm/test_object_list_p2p.py
+++ b/tests/test_legacy/test_comm/test_object_list_p2p.py
@@ -1,9 +1,6 @@
import pytest
import torch
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
from colossalai.legacy.communication.p2p import (
recv_backward,
recv_forward,
@@ -12,6 +9,9 @@
send_forward,
send_forward_recv_backward,
)
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(pipeline=2))
diff --git a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
index 040c63322f2b..f805bd19d7e8 100644
--- a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
+++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
@@ -1,10 +1,10 @@
import pytest
import torch
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
from colossalai.legacy.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_legacy/test_context/configs/parallel_2d_init.py b/tests/test_legacy/test_context/configs/parallel_2d_init.py
new file mode 100644
index 000000000000..6cf816942fdd
--- /dev/null
+++ b/tests/test_legacy/test_context/configs/parallel_2d_init.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode='2d'))
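+# 2 pipeline stages x 4-way 2D tensor parallelism -> 8 processes in total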
diff --git a/tests/test_legacy/test_context/configs/parallel_2p5d_init.py b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py
new file mode 100644
index 000000000000..b946d45b3a91
--- /dev/null
+++ b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode='2.5d'))
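+# 2 pipeline stages x 8-way 2.5D tensor parallelism (depth 2) -> 16 processes in total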
diff --git a/tests/test_legacy/test_context/configs/parallel_3d_init.py b/tests/test_legacy/test_context/configs/parallel_3d_init.py
new file mode 100644
index 000000000000..a1564bbb2d51
--- /dev/null
+++ b/tests/test_legacy/test_context/configs/parallel_3d_init.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode='3d'))
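+# 2 pipeline stages x 8-way 3D tensor parallelism -> 16 processes in total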
diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_legacy/test_context/test_hybrid_parallel.py
similarity index 95%
rename from tests/test_context/test_hybrid_parallel.py
rename to tests/test_legacy/test_context/test_hybrid_parallel.py
index d25668afd430..05cd1d294dcd 100644
--- a/tests/test_context/test_hybrid_parallel.py
+++ b/tests/test_legacy/test_context/test_hybrid_parallel.py
@@ -6,11 +6,11 @@
import pytest
import torch
-from colossalai import launch
-from colossalai.context import reset_seeds
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as tp_env
+from colossalai.legacy import launch
+from colossalai.legacy.context import reset_seeds
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as tp_env
from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_legacy/test_data/test_cifar10_dataset.py
similarity index 100%
rename from tests/test_data/test_cifar10_dataset.py
rename to tests/test_legacy/test_data/test_cifar10_dataset.py
diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_legacy/test_data/test_data_parallel_sampler.py
similarity index 87%
rename from tests/test_data/test_data_parallel_sampler.py
rename to tests/test_legacy/test_data/test_data_parallel_sampler.py
index 7beef707c096..cf10fe9dfa3c 100644
--- a/tests/test_data/test_data_parallel_sampler.py
+++ b/tests/test_legacy/test_data/test_data_parallel_sampler.py
@@ -10,10 +10,11 @@
from torchvision import datasets, transforms
import colossalai
-from colossalai.context import Config, ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.context import Config
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils import get_dataloader
from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_dataloader
CONFIG = Config(dict(
parallel=dict(
@@ -26,7 +27,7 @@
def run_data_sampler(rank, world_size, port):
dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
- colossalai.launch(**dist_args)
+ colossalai.legacy.launch(**dist_args)
print('finished initialization')
# build dataset
diff --git a/tests/test_legacy/test_data/test_deterministic_dataloader.py b/tests/test_legacy/test_data/test_deterministic_dataloader.py
new file mode 100644
index 000000000000..421b8d255318
--- /dev/null
+++ b/tests/test_legacy/test_data/test_deterministic_dataloader.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import os
+from pathlib import Path
+
+import pytest
+import torch
+import torch.distributed as dist
+from torchvision import datasets, transforms
+
+import colossalai
+from colossalai.context import Config
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils import get_dataloader
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
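+# legacy-style launch config: a single pipeline stage, no tensor parallelism, CIFAR10 read from $DATA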
+CONFIG = Config(
+ dict(
+ train_data=dict(
+ dataset=dict(
+ type='CIFAR10',
+ root=Path(os.environ['DATA']),
+ train=True,
+ download=True,
+ ),
+ dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
+ ),
+ parallel=dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=1, mode=None),
+ ),
+ seed=1024,
+ ))
+
+
+def run_data_sampler(rank, world_size, port):
+ dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
+ colossalai.legacy.launch(**dist_args)
+
+ # build dataset
+ transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
+ transform_pipeline = transforms.Compose(transform_pipeline)
+ dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
+
+ # build dataloader
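+    # add_sampler=False: no distributed sampler is attached, so every data-parallel rank should see identical batches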
+ dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)
+
+ data_iter = iter(dataloader)
+    img, label = next(data_iter)
+ img = img[0]
+
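+    # rank 0 broadcasts its first image; every other rank compares its own copy against it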
+ if gpc.get_local_rank(ParallelMode.DATA) != 0:
+ img_to_compare = img.clone()
+ else:
+ img_to_compare = img
+ dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))
+
+ if gpc.get_local_rank(ParallelMode.DATA) != 0:
+        # no sampler is attached, so every rank should load the same data
+        # this check would fail if a data-parallel sampler were given to the dataloader
+        assert torch.equal(img,
+                           img_to_compare), 'expected all ranks to receive the same image when no sampler is used'
+ torch.cuda.empty_cache()
+
+
+@rerun_if_address_is_in_use()
+def test_data_sampler():
+ spawn(run_data_sampler, 4)
+
+
+if __name__ == '__main__':
+ test_data_sampler()
diff --git a/tests/test_legacy/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py
index 62493cf3712d..8499784038d2 100644
--- a/tests/test_legacy/test_engine/test_engine.py
+++ b/tests/test_legacy/test_engine/test_engine.py
@@ -1,8 +1,8 @@
import pytest
import colossalai
-from colossalai.amp import AMP_TYPE
-from colossalai.core import global_context as gpc
+from colossalai.legacy.amp import AMP_TYPE
+from colossalai.legacy.core import global_context as gpc
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -20,10 +20,11 @@ def run_train(model_name, amp_mode):
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
model = model_builder(checkpoint=False)
- engine, train_dataloader, *args = colossalai.initialize(model=model,
- optimizer=optimizer_class(model.parameters(), lr=1e-3),
- criterion=criterion,
- train_dataloader=train_dataloader)
+ engine, train_dataloader, *args = colossalai.legacy.initialize(model=model,
+ optimizer=optimizer_class(model.parameters(),
+ lr=1e-3),
+ criterion=criterion,
+ train_dataloader=train_dataloader)
try:
engine.train()
@@ -48,7 +49,12 @@ def run_train(model_name, amp_mode):
def run_engine(rank, world_size, port):
# init dist env
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config=CONFIG,
+ rank=rank,
+ world_size=world_size,
+ host='localhost',
+ port=port,
+ backend='nccl')
run_train()
diff --git a/tests/test_legacy/test_engine/test_gradient_accumluation.py b/tests/test_legacy/test_engine/test_gradient_accumluation.py
index 7783827c7c44..168c93c1a572 100644
--- a/tests/test_legacy/test_engine/test_gradient_accumluation.py
+++ b/tests/test_legacy/test_engine/test_gradient_accumluation.py
@@ -10,10 +10,10 @@
from torchvision.models import resnet18
import colossalai
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils import get_dataloader
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_dataloader
# Config
BATCH_SIZE = 2
@@ -27,7 +27,12 @@
def run_no_pipeline(rank, world_size, port):
# init dist env
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config=CONFIG,
+ rank=rank,
+ world_size=world_size,
+ host='localhost',
+ port=port,
+ backend='nccl')
# build model
model = resnet18(num_classes=10)
@@ -49,10 +54,10 @@ def run_no_pipeline(rank, world_size, port):
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
- engine, train_dataloader, *args = colossalai.initialize(model=model,
- optimizer=optimizer,
- criterion=criterion,
- train_dataloader=train_dataloader)
+ engine, train_dataloader, *args = colossalai.legacy.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader)
logger = get_dist_logger()
rank = torch.distributed.get_rank()
param_track = []
diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
index dcb2be62671b..859707e6129d 100644
--- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
+++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
@@ -2,9 +2,9 @@
import torch.distributed as dist
from torch.nn import Parameter
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.nn import (
Classifier1D,
Embedding1D,
@@ -15,7 +15,8 @@
VocabParallelCrossEntropyLoss1D,
VocabParallelEmbedding1D,
)
-from colossalai.utils import get_current_device, print_rank_0
+from colossalai.legacy.utils import print_rank_0
+from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
diff --git a/tests/test_legacy/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py
index 891512542475..2a016ed7b33d 100644
--- a/tests/test_legacy/test_layers/test_1d/test_1d.py
+++ b/tests/test_legacy/test_layers/test_1d/test_1d.py
@@ -5,8 +5,8 @@
import torch
from checks_1d.check_layer_1d import *
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
index 0ee88c26035f..494497be33e2 100644
--- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
@@ -1,7 +1,7 @@
import torch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn import (
Classifier2D,
CrossEntropyLoss2D,
@@ -15,7 +15,8 @@
VocabParallelCrossEntropyLoss2D,
VocabParallelEmbedding2D,
)
-from colossalai.utils import get_current_device, print_rank_0
+from colossalai.legacy.utils import print_rank_0
+from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
index ae1d1120cfb9..034dbe5ca29c 100644
--- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
@@ -3,10 +3,11 @@
import torch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
-from colossalai.utils import get_current_device, print_rank_0
+from colossalai.legacy.utils import print_rank_0
+from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal
diff --git a/tests/test_legacy/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py
index bcea5ce7b25d..a4b46793f19d 100644
--- a/tests/test_legacy/test_layers/test_2d/test_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/test_2d.py
@@ -18,8 +18,8 @@
)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
index 5a99b05cfe7e..e7a9a8be45d0 100644
--- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
@@ -1,8 +1,8 @@
import torch
from torch.nn import Parameter
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn import (
Classifier2p5D,
CrossEntropyLoss2p5D,
@@ -16,7 +16,8 @@
VocabParallelCrossEntropyLoss2p5D,
VocabParallelEmbedding2p5D,
)
-from colossalai.utils import get_current_device, print_rank_0
+from colossalai.legacy.utils import print_rank_0
+from colossalai.utils import get_current_device
from .common import *
diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
index db19967676d2..fe78ef669bf0 100644
--- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
@@ -1,9 +1,10 @@
import torch
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D
-from colossalai.utils import get_current_device, print_rank_0
+from colossalai.legacy.utils import print_rank_0
+from colossalai.utils import get_current_device
from .common import *
diff --git a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
index 373d834d0032..38ba3ba78575 100644
--- a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
@@ -3,8 +3,8 @@
from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
index cee639a9f00a..2a9dcc3cdc16 100644
--- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
+++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
@@ -5,8 +5,8 @@
import torch
-from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.core import global_context
+from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.legacy.core import global_context
from colossalai.legacy.nn import (
Classifier3D,
CrossEntropyLoss3D,
@@ -21,8 +21,9 @@
VocabParallelEmbedding3D,
)
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.legacy.utils import print_rank_0
from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device, print_rank_0
+from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
diff --git a/tests/test_legacy/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py
index fde71a4a0d26..2a32d8935c00 100644
--- a/tests/test_legacy/test_layers/test_3d/test_3d.py
+++ b/tests/test_legacy/test_layers/test_3d/test_3d.py
@@ -15,8 +15,8 @@
check_vocab_parallel_loss,
)
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py
index 0760a3f1ec38..c58445a396ec 100644
--- a/tests/test_legacy/test_layers/test_cache_embedding.py
+++ b/tests/test_legacy/test_layers/test_cache_embedding.py
@@ -14,7 +14,8 @@
ParallelCachedEmbeddingBagTablewise,
TablewiseEmbeddingBagConfig,
)
-from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.tensor import ColoTensor
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
NUM_EMBED, EMBED_DIM = 10, 8
@@ -359,7 +360,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# run_parallel_freq_aware_embed_columnwise(rank, world_size)
run_parallel_freq_aware_embed_tablewise(rank, world_size)
diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
index 7ff91a7b76e0..ac9493adab2e 100644
--- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
+++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
@@ -1,7 +1,7 @@
import torch
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn import TransformerSelfAttentionRing
from colossalai.utils import get_current_device
diff --git a/tests/test_legacy/test_layers/test_sequence/test_sequence.py b/tests/test_legacy/test_layers/test_sequence/test_sequence.py
index b9e6c12479ee..85226f9d934a 100644
--- a/tests/test_legacy/test_layers/test_sequence/test_sequence.py
+++ b/tests/test_legacy/test_layers/test_sequence/test_sequence.py
@@ -3,8 +3,8 @@
import torch.distributed as dist
import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -120,7 +120,7 @@ def check_ring_av(rank, world_size):
def run_test(rank, world_size, port):
- colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port)
+ colossalai.legacy.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port)
# check_ring_qk(rank, world_size)
check_ring_av(rank, world_size)
diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_legacy/test_pipeline/rpc_test_utils.py
similarity index 97%
rename from tests/test_pipeline/rpc_test_utils.py
rename to tests/test_legacy/test_pipeline/rpc_test_utils.py
index dab474a4ee21..9a336c4224be 100644
--- a/tests/test_pipeline/rpc_test_utils.py
+++ b/tests/test_legacy/test_pipeline/rpc_test_utils.py
@@ -10,9 +10,9 @@
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop
-from colossalai import launch
+from colossalai.legacy import launch
+from colossalai.legacy.pipeline.pipeline_process_group import ppg
from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.pipeline_process_group import ppg
rpc_is_initialized = _is_current_rpc_agent_set
diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py
similarity index 94%
rename from tests/test_pipeline/test_cuda_rpc_chimera.py
rename to tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py
index 45ad8f828e61..3bff08318d40 100644
--- a/tests/test_pipeline/test_cuda_rpc_chimera.py
+++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py
@@ -1,10 +1,10 @@
import torch
-from torch import nn
import torch.autograd as autograd
+from rpc_test_utils import RpcTestModel, parse_args, rpc_run
+from torch import nn
-from colossalai.pipeline.rpc import ChimeraPipelineEngine
+from colossalai.legacy.pipeline.rpc import ChimeraPipelineEngine
from colossalai.testing import assert_close
-from rpc_test_utils import rpc_run, parse_args, RpcTestModel
# global variable for model created
feat_num = 100
diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py
similarity index 89%
rename from tests/test_pipeline/test_cuda_rpc_optimizer.py
rename to tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py
index 842566730caf..eff031ff8faa 100644
--- a/tests/test_pipeline/test_cuda_rpc_optimizer.py
+++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py
@@ -1,11 +1,10 @@
import torch
-from torch import nn
-from torch import autograd
-from torch.optim import SGD, Adam, RMSprop, Optimizer
+from rpc_test_utils import RpcTestModel, parse_args, rpc_run
+from torch import autograd, nn
+from torch.optim import SGD, Adam, Optimizer, RMSprop
-from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
-from rpc_test_utils import rpc_run, parse_args, RpcTestModel
# global variable for model created
feat_num = 100
diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py
similarity index 87%
rename from tests/test_pipeline/test_cuda_rpc_pipeline.py
rename to tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py
index 8d03e79813e8..1a6077f8d3e9 100644
--- a/tests/test_pipeline/test_cuda_rpc_pipeline.py
+++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py
@@ -1,8 +1,8 @@
import torch
+from rpc_test_utils import RpcTestModel, parse_args, rpc_run
from torch import nn
-from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
-from rpc_test_utils import rpc_run, parse_args, RpcTestModel
+from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
# global variable for model created
feat_num = 100
diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py
similarity index 91%
rename from tests/test_pipeline/test_cuda_rpc_value_correctness.py
rename to tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py
index e6713478baec..43966ce3dbda 100644
--- a/tests/test_pipeline/test_cuda_rpc_value_correctness.py
+++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py
@@ -1,10 +1,9 @@
import torch
-from torch import nn
-from torch import autograd
+from rpc_test_utils import RpcTestModel, parse_args, rpc_run
+from torch import autograd, nn
-from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
-from rpc_test_utils import rpc_run, parse_args, RpcTestModel
feat_num = 100
h = 100
diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py
similarity index 94%
rename from tests/test_pipeline/test_middleware_1f1b.py
rename to tests/test_legacy/test_pipeline/test_middleware_1f1b.py
index 5b3aad703275..4e43d52f8aee 100644
--- a/tests/test_pipeline/test_middleware_1f1b.py
+++ b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py
@@ -7,13 +7,13 @@
from rpc_test_utils import DAG_MLP, MLP
from torch._C._distributed_rpc import _is_current_rpc_agent_set
-from colossalai import launch
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
+from colossalai.legacy import launch
+from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology
+from colossalai.legacy.pipeline.pipeline_process_group import ppg
+from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.middleware.adaptor import get_fx_topology
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# global variable for model created
diff --git a/tests/test_pipeline/test_pipelinable.py b/tests/test_legacy/test_pipeline/test_pipelinable.py
similarity index 96%
rename from tests/test_pipeline/test_pipelinable.py
rename to tests/test_legacy/test_pipeline/test_pipelinable.py
index bb016596beea..2ba5d0aa24d8 100644
--- a/tests/test_pipeline/test_pipelinable.py
+++ b/tests/test_legacy/test_pipeline/test_pipelinable.py
@@ -1,7 +1,7 @@
import pytest
import torch
-from colossalai.pipeline.pipelinable import PipelinableContext
+from colossalai.legacy.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn
NUM_CHUNKS = 1
diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py
similarity index 91%
rename from tests/test_pipeline/test_pipeline_process_group.py
rename to tests/test_legacy/test_pipeline/test_pipeline_process_group.py
index 2a00e3ac55b1..e6b95660279b 100644
--- a/tests/test_pipeline/test_pipeline_process_group.py
+++ b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py
@@ -3,9 +3,9 @@
import torch.distributed.rpc as rpc
from rpc_test_utils import pg_parse_args, rpc_is_initialized
-from colossalai.initialize import launch
+from colossalai.legacy.initialize import launch
+from colossalai.legacy.pipeline.pipeline_process_group import ppg
from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.testing import spawn
diff --git a/tests/test_tensor/common_utils/__init__.py b/tests/test_legacy/test_tensor/common_utils/__init__.py
similarity index 95%
rename from tests/test_tensor/common_utils/__init__.py
rename to tests/test_legacy/test_tensor/common_utils/__init__.py
index 5387db70445f..9a35d02ce5ed 100644
--- a/tests/test_tensor/common_utils/__init__.py
+++ b/tests/test_legacy/test_tensor/common_utils/__init__.py
@@ -1 +1 @@
-from ._utils import *
+from ._utils import *
diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_legacy/test_tensor/common_utils/_utils.py
similarity index 93%
rename from tests/test_tensor/common_utils/_utils.py
rename to tests/test_legacy/test_tensor/common_utils/_utils.py
index b405f8cd2108..b6fea28e4c8a 100644
--- a/tests/test_tensor/common_utils/_utils.py
+++ b/tests/test_legacy/test_tensor/common_utils/_utils.py
@@ -6,9 +6,9 @@
import torch.distributed as dist
from torch.testing import assert_close
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.tensor import ComputePattern, ComputeSpec, ShardSpec
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ShardSpec
def set_seed(seed):
diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
similarity index 91%
rename from tests/test_tensor/core/test_dist_spec_mgr.py
rename to tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
index 89476a35b63a..b6d6bcee66ce 100644
--- a/tests/test_tensor/core/test_dist_spec_mgr.py
+++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
@@ -5,7 +5,7 @@
import torch.distributed as dist
import colossalai
-from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.legacy.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -48,7 +48,7 @@ def check_mem():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_mem()
run()
diff --git a/tests/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py
similarity index 82%
rename from tests/test_tensor/test_parameter.py
rename to tests/test_legacy/test_tensor/test_parameter.py
index 9c3f05da1ffa..7a8694ff6789 100644
--- a/tests/test_tensor/test_parameter.py
+++ b/tests/test_legacy/test_tensor/test_parameter.py
@@ -3,13 +3,13 @@
from common_utils import tensor_equal
import colossalai
-from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.testing import free_port
@pytest.mark.skip
def test_multiinheritance():
- colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+ colossalai.legacy.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
colo_param = ColoParameter(None, requires_grad=True)
assert colo_param.dist_spec.placement.value == 'r'
assert isinstance(colo_param, ColoTensor)
diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
index 5fb678525bb3..84652093a9fd 100644
--- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
+++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
@@ -5,9 +5,6 @@
import torch
import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
from colossalai.legacy.communication import (
recv_backward,
recv_forward,
@@ -18,6 +15,9 @@
send_forward_recv_backward,
send_obj_meta,
)
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
index 6d7bf6b3d89f..fd94c279b6fb 100644
--- a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
+++ b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
@@ -11,11 +11,11 @@
from torchvision.models import resnet18
import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
+from colossalai.legacy.utils import get_dataloader, print_rank_0
from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_dataloader, print_rank_0
BATCH_SIZE = 8
@@ -63,7 +63,7 @@ def forward(self, x):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
# initialize
- engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
+ engine, train_dataloader, _, _ = colossalai.legacy.initialize(model, optimizer, criterion, train_dataloader)
# build pipeline schedule
schedule = engine.schedule
diff --git a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
index dab0e53a4c32..4a240533474c 100644
--- a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
+++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
@@ -2,7 +2,7 @@
import torch
import colossalai
-from colossalai.amp.amp_type import AMP_TYPE
+from colossalai.legacy.amp.amp_type import AMP_TYPE
from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -22,10 +22,10 @@ def run_trainer(model_name):
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
optimizer = optimizer_class(model.parameters(), lr=1e-3)
- engine, train_dataloader, *_ = colossalai.initialize(model=model,
- optimizer=optimizer,
- criterion=criterion,
- train_dataloader=train_dataloader)
+ engine, train_dataloader, *_ = colossalai.legacy.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader)
logger = get_dist_logger()
logger.info("engine is built", ranks=[0])
@@ -45,7 +45,12 @@ def run_trainer(model_name):
def run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config=CONFIG,
+ rank=rank,
+ world_size=world_size,
+ host='localhost',
+ port=port,
+ backend='nccl')
@pytest.mark.dist
diff --git a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
index 7dfbec854ccc..521b2f32f22d 100644
--- a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
+++ b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
@@ -10,12 +10,13 @@
from torchvision.models import resnet18
import colossalai
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.trainer import Trainer
+from colossalai.legacy.utils import get_dataloader
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.utils import MultiTimer
BATCH_SIZE = 4
IMG_SIZE = 32
@@ -28,7 +29,12 @@
def run_trainer_with_pipeline(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config=CONFIG,
+ rank=rank,
+ world_size=world_size,
+ host='localhost',
+ port=port,
+ backend='nccl')
# build model
model = resnet18(num_classes=10)
@@ -63,10 +69,10 @@ def forward(self, x):
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
- engine, train_dataloader, *args = colossalai.initialize(model=model,
- optimizer=optimizer,
- criterion=criterion,
- train_dataloader=train_dataloader)
+ engine, train_dataloader, *args = colossalai.legacy.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader)
logger = get_dist_logger()
logger.info("engine is built", ranks=[0])
diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_legacy/test_utils/test_activation_checkpointing.py
similarity index 94%
rename from tests/test_utils/test_activation_checkpointing.py
rename to tests/test_legacy/test_utils/test_activation_checkpointing.py
index b7764c2f4371..19984ae120b5 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_legacy/test_utils/test_activation_checkpointing.py
@@ -5,10 +5,10 @@
import torch
import torch.nn.functional as F
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.random import add_seed, reset_seeds, seed, set_mode
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.context.random import add_seed, reset_seeds, seed, set_mode
+from colossalai.legacy.utils.activation_checkpoint import checkpoint
from colossalai.testing import clear_cache_before_run, parameterize
-from colossalai.utils.activation_checkpoint import checkpoint
def forward(x, weight):
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py
similarity index 83%
rename from tests/test_utils/test_checkpoint/test_checkpoint_1d.py
rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py
index 9c3a7e2161d2..88cd89a217fe 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py
@@ -8,17 +8,17 @@
import torch.nn as nn
import colossalai.legacy.nn as col_nn
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
+from colossalai.legacy.utils import is_using_pp
+from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
-from colossalai.utils import is_using_pp
-from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
- from colossalai.pipeline.utils import partition_uniform
+ from colossalai.legacy.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py
similarity index 83%
rename from tests/test_utils/test_checkpoint/test_checkpoint_2d.py
rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py
index 03b2e4f2a9b2..591cd714fc65 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py
@@ -8,17 +8,17 @@
import torch.nn as nn
import colossalai.legacy.nn as col_nn
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
+from colossalai.legacy.utils import is_using_pp
+from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
-from colossalai.utils import is_using_pp
-from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
- from colossalai.pipeline.utils import partition_uniform
+ from colossalai.legacy.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py
similarity index 84%
rename from tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py
index cafffd0a6202..b165b4276f10 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py
@@ -8,17 +8,17 @@
import torch.nn as nn
import colossalai.legacy.nn as col_nn
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
+from colossalai.legacy.utils import is_using_pp
+from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
-from colossalai.utils import is_using_pp
-from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
- from colossalai.pipeline.utils import partition_uniform
+ from colossalai.legacy.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py
similarity index 83%
rename from tests/test_utils/test_checkpoint/test_checkpoint_3d.py
rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py
index 9b43be9e8cc5..2ce054d33b2d 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py
@@ -8,17 +8,17 @@
import torch.nn as nn
import colossalai.legacy.nn as col_nn
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.initialize import launch
+from colossalai.legacy.utils import is_using_pp
+from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
-from colossalai.utils import is_using_pp
-from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
- from colossalai.pipeline.utils import partition_uniform
+ from colossalai.legacy.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
diff --git a/tests/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py
similarity index 76%
rename from tests/test_utils/test_memory.py
rename to tests/test_legacy/test_utils/test_memory.py
index c88c2f8ec3c5..2e25dc773b68 100644
--- a/tests/test_utils/test_memory.py
+++ b/tests/test_legacy/test_utils/test_memory.py
@@ -1,9 +1,9 @@
import pytest
import colossalai
+from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
@@ -14,7 +14,7 @@ def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()
diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
similarity index 91%
rename from tests/test_utils/test_norm_gradient_clipping.py
rename to tests/test_legacy/test_utils/test_norm_gradient_clipping.py
index 4fd7c3c60a95..918f174aba76 100644
--- a/tests/test_utils/test_norm_gradient_clipping.py
+++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
@@ -4,12 +4,12 @@
from torch.nn.utils import clip_grad_norm_
import colossalai
+from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec
+from colossalai.legacy.utils.common import clip_grad_norm
from colossalai.logging import disable_existing_loggers
-from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
-from colossalai.utils.common import clip_grad_norm
def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
@@ -62,7 +62,7 @@ def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_ty
def run_dist(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_grad_clip_norm(world_size=world_size)
diff --git a/tests/test_utils/test_commons.py b/tests/test_legacy/test_zero/test_commons.py
similarity index 82%
rename from tests/test_utils/test_commons.py
rename to tests/test_legacy/test_zero/test_commons.py
index 2633d7da21aa..42a9f1eecb95 100644
--- a/tests/test_utils/test_commons.py
+++ b/tests/test_legacy/test_zero/test_commons.py
@@ -1,13 +1,13 @@
import torch
import colossalai
+from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
+from colossalai.legacy.zero.sharded_param import ShardedTensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
-from colossalai.zero.legacy.sharded_param import ShardedTensor
def run_tensor_move(rank, world_size, port):
- colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl')
+ colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl')
src_t = torch.ones(2, 3).cuda()
tgt_t = torch.zeros(2, 3)
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 39603c158731..c096b6075005 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -3,9 +3,9 @@
import torch.nn as nn
import colossalai
-from colossalai.context import ParallelMode
from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc
from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index a43ae764dccd..35fde6f10f3f 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -2,8 +2,8 @@
import torch
import colossalai
-from colossalai.amp import convert_to_apex_amp
from colossalai.context import MOE_CONTEXT
+from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss
from colossalai.nn.optimizer import CPUAdam
diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py
index 2c68633aabc8..4a3199c1c53d 100644
--- a/tests/test_tensor/test_comm_spec_apply.py
+++ b/tests/test_tensor/test_comm_spec_apply.py
@@ -1,7 +1,7 @@
import pytest
import torch
+import torch.distributed as dist
-from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@@ -184,7 +184,7 @@ def check_comm(rank, world_size, port):
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
- assert rank == gpc.get_global_rank()
+ assert rank == dist.get_rank()
mesh_shape = (2, 2)
# [[0, 1,
@@ -205,7 +205,6 @@ def check_comm(rank, world_size, port):
# test all reduce in 1D flatten device mesh
check_all_reduce_in_flatten_device_mesh(device_mesh, rank)
- gpc.destroy()
@pytest.mark.dist
diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py
index 95fcd2aaf8f3..a1ea2946e6e7 100644
--- a/tests/test_tensor/test_dtensor/test_comm_spec.py
+++ b/tests/test_tensor/test_dtensor/test_comm_spec.py
@@ -1,7 +1,7 @@
import pytest
import torch
+import torch.distributed as dist
-from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@@ -127,7 +127,7 @@ def check_comm(rank, world_size, port):
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
- assert rank == gpc.get_global_rank()
+ assert rank == dist.get_rank()
mesh_shape = (2, 2)
# [[0, 1,
@@ -149,8 +149,6 @@ def check_comm(rank, world_size, port):
check_all_reduce_fwd(process_group_dict, rank)
check_all_reduce_bwd(process_group_dict, rank)
- gpc.destroy()
-
@pytest.mark.dist
@rerun_if_address_is_in_use()
diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py
index 9122808eb5a3..bd71bffccc70 100644
--- a/tests/test_tensor/test_mix_gather.py
+++ b/tests/test_tensor/test_mix_gather.py
@@ -1,7 +1,7 @@
import pytest
import torch
+import torch.distributed as dist
-from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@@ -295,7 +295,7 @@ def check_comm(rank, world_size, port):
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 8)
- assert rank == gpc.get_global_rank()
+ assert rank == dist.get_rank()
mesh_shape = (2, 4)
# [[0, 1, 2, 3],
diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py
deleted file mode 100644
index e99cf388e929..000000000000
--- a/tests/test_utils/test_zero_gradient_clippling.py
+++ /dev/null
@@ -1,111 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from functools import partial
-
-import pytest
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_norm_
-
-import colossalai
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import checkpoint, clip_grad_norm_fp32
-from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
-from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
-
-
-def checkpoint_wrapper(module, enable=True):
- if enable:
- module.forward = partial(checkpoint, module.forward, False)
- return module
-
-
-class Net(nn.Module):
-
- def __init__(self, checkpoint=False) -> None:
- super().__init__()
- self.fc1 = nn.Linear(5, 5)
- self.fc2 = nn.Linear(5, 5)
- self.fc3 = nn.Linear(5, 1)
- if checkpoint:
- self.fc1 = checkpoint_wrapper(self.fc1)
- self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
-
- def forward(self, x):
- for layer in self.layers:
- x = layer(x)
- return x
-
-
-def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
- model.train()
- optimizer.zero_grad()
- with torch.cuda.amp.autocast(enabled=enable_autocast):
- y = model(x)
- loss = y.sum()
- loss = loss.float()
- loss.backward()
- clip_grad(model, norm_type)
- optimizer.step()
-
-
-def clip_grad(model, norm_type):
- if isinstance(model, DDP):
- clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
- else:
- clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)
-
-
-def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
- if loose:
- return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
- return torch.allclose(tensor_a, tensor_b)
-
-
-def check_grads(model, zero_model, loose=False):
- rank = dist.get_rank()
- for p, zero_p in zip(model.parameters(), zero_model.parameters()):
- zero_grad = zero_p.grad.clone().to(p.device)
- chunks = torch.flatten(p.grad).chunk(4)
- if rank >= len(chunks):
- continue
- grad = chunks[rank]
- if zero_p.zero_shard_padding > 0:
- zero_grad = zero_grad[:-zero_p.zero_shard_padding]
- assert grad.dtype == zero_grad.dtype
- assert allclose(grad, zero_grad, loose=loose)
-
-
-def check_params(model, zero_model, loose=False):
- rank = dist.get_rank()
- for p, zero_p in zip(model.parameters(), zero_model.parameters()):
- zero_shard_padding = zero_p.zero_shard_padding
- zero_p = zero_p.clone().to(p.device)
- chunks = torch.flatten(p).chunk(4)
- if rank >= len(chunks):
- continue
- p = chunks[rank]
- if zero_shard_padding > 0:
- zero_p = zero_p[:-zero_shard_padding]
- assert p.dtype == zero_p.dtype
- assert allclose(p, zero_p, loose=loose)
-
-
-def run_dist(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_zero_clip_grad():
- world_size = 4
- spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_zero_clip_grad()
diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py
index d6c4f8bd8aac..f05ccfdbd41b 100644
--- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py
+++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py
@@ -6,7 +6,6 @@
from colossalai.tensor import ColoTensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.gemini.chunk import ChunkManager
-from tests.test_tensor.common_utils import debug_print
CUDA_MEM_0 = {False: 512, True: 1024}
CUDA_MEM_1 = {False: 0, True: 1024}
@@ -16,7 +15,6 @@
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_memory(keep_gathered, pin_memory):
- debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)]
config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 4cbf564ecfb9..fabdd6072c31 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -5,15 +5,15 @@
from torch.testing import assert_close
import colossalai
-from colossalai.amp import convert_to_apex_amp
+from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py
index a80a2f62de22..614a96ccdbcd 100644
--- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py
+++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py
@@ -4,12 +4,12 @@
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
# run gemini use the runtime memory tracer
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index 82b9133b89c1..860d6efa899a 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -5,14 +5,14 @@
from torch.testing import assert_close
import colossalai
-from colossalai.amp import convert_to_apex_amp
+from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index 20d145f9661f..99ee08c1d7e7 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -7,15 +7,15 @@
from torch.testing import assert_close
import colossalai
-from colossalai.amp import convert_to_apex_amp
+from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index edcbada0acbb..3454959199d2 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -5,15 +5,15 @@
from torch.testing import assert_close
import colossalai
-from colossalai.amp import convert_to_apex_amp
+from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
index 656bd709e2a1..602e3ad3519d 100644
--- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
@@ -4,10 +4,10 @@
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
index 09725e11ec0c..5f7b51510d58 100644
--- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
@@ -5,10 +5,10 @@
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
PLACEMENT_CONFIGS = [
{
diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py
deleted file mode 100644
index 4a2b49f63b7e..000000000000
--- a/tests/test_zero/test_low_level/test_zero_tp.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import pytest
-import torch
-import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.testing import assert_close
-
-import colossalai
-from colossalai.tensor import ProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
-from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal
-
-
-def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
- return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
-
-
-class MlpModel(nn.Module):
-
- def __init__(self):
- super(MlpModel, self).__init__()
- self.linear1 = nn.Linear(32, 128)
- self.act = nn.GELU()
- self.linear2 = nn.Linear(128, 32)
-
- def forward(self, x):
- y = self.linear1(x)
- y = self.act(y)
- y = self.linear2(y)
- return x + y
-
-
-@parameterize("overlap_flag", [False, True])
-@parameterize("partition_flag", [False, True])
-def exam_zero_with_tp(overlap_flag, partition_flag):
- set_seed(233010)
- tp_pg = ProcessGroup(tp_degree=2)
-
- with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
- hybrid_model = MlpModel()
- torch_model = MlpModel().cuda()
- for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
- pt.data.copy_(ph.data)
-
- for name, param in hybrid_model.named_parameters():
- if 'linear1' in name:
- split_param_row_tp1d(param, tp_pg)
- param.compute_spec.set_output_replicate(False)
- if 'linear2.weight' in name:
- split_param_col_tp1d(param, tp_pg)
-
- torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
- torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11
- hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2)
- hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
- initial_scale=2,
- clip_grad_norm=1.0,
- overlap_communication=overlap_flag,
- partition_grad=partition_flag,
- dp_process_group=tp_pg.dp_process_group(),
- tp_process_group=tp_pg.tp_process_group())
-
- dp_local_rank = tp_pg.dp_local_rank()
- set_seed(255 + dp_local_rank)
-
- data = torch.randn(8, 32, device=get_current_device())
- torch_loss = torch_model(data).sum()
- hybrid_loss = hybrid_model(data).sum()
- assert_close(torch_loss, hybrid_loss)
-
- torch_loss.backward()
- torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
- hybrid_optim.backward(hybrid_loss)
-
- torch_optim.step()
- hybrid_optim.step()
-
- for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()):
- assert strict_shard_equal(pt.data, ph.data, tp_pg)
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
- exam_zero_with_tp()
-
-
-@pytest.mark.skip('this will be rewritten by shardformer')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_zero_with_tp():
- spawn(run_dist, 4)
-
-
-if __name__ == '__main__':
- test_zero_with_tp()
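Note on the patch above: the changes are mechanical. Symbols that previously lived at the top level of colossalai (core, context, initialize, pipeline, amp, the tensor spec types, and several utils) now resolve under colossalai.legacy, while still-maintained helpers such as get_current_device keep their old paths, and launch/initialize call sites gain a .legacy segment. A minimal before/after sketch of the pattern, using only names that appear in the hunks above (not an exhaustive mapping of the move):

    import colossalai
    from colossalai.legacy.core import global_context as gpc  # was: colossalai.core
    from colossalai.legacy.initialize import launch           # was: colossalai.initialize
    from colossalai.legacy.utils import print_rank_0          # was: colossalai.utils
    from colossalai.utils import get_current_device           # unchanged: stays in colossalai.utils

    def run_dist(rank, world_size, port):
        # call sites change the same way: colossalai.launch -> colossalai.legacy.launch
        colossalai.legacy.launch(config={}, rank=rank, world_size=world_size,
                                 host='localhost', port=port, backend='nccl')
        print_rank_0(f"running on {get_current_device()}")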
From 3c6b831c264d0657a97034b5cf036c913a762083 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Mon, 18 Sep 2023 16:52:42 +0800
Subject: [PATCH 19/58] [format] applied code formatting on changed files in
pull request 4743 (#4750)
Co-authored-by: github-actions
---
.../language/gpt/experiments/pipeline_parallel/train_gpt_pp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
index 30d6aab4f12f..749243e57836 100644
--- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
+++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
@@ -3,6 +3,7 @@
from functools import partial
import torch
+from model_zoo import model_builder
from torch import nn
from tqdm import tqdm
@@ -18,7 +19,6 @@
from colossalai.legacy.pipeline.rpc.utils import rpc_run
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
-from model_zoo import model_builder
def parse_args():
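The single hunk above only reorders imports: the formatter files model_zoo (an example-local helper) into the third-party import group, where it sorts alphabetically ahead of torch, instead of leaving it after the colossalai imports. Assuming isort's default grouping (the .isort.cfg tweak lands in the next patch), the resulting block reads:

    import torch
    from model_zoo import model_builder  # moved up into the third-party group
    from torch import nn
    from tqdm import tqdm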
From 079bf3cb26a502fc647b1aad15fd14d6266be66c Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 19 Sep 2023 14:20:26 +0800
Subject: [PATCH 20/58] [misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
---
.flake8 | 22 -
.github/workflows/scripts/check_doc_i18n.py | 12 +-
.../example_checks/check_dispatch_inputs.py | 10 +-
.../example_checks/check_example_weekly.py | 10 +-
.../example_checks/detect_changed_example.py | 6 +-
.../generate_leaderboard_and_send_to_lark.py | 137 +-
.../scripts/generate_release_draft.py | 51 +-
.../workflows/scripts/send_message_to_lark.py | 6 +-
.isort.cfg | 1 +
.pre-commit-config.yaml | 18 +-
.style.yapf | 5 -
.../benchmarks/benchmark_opt_lora_dummy.py | 212 +-
.../Chat/benchmarks/ray/1mmt_dummy.py | 116 +-
.../Chat/benchmarks/ray/mmmt_dummy.py | 128 +-
applications/Chat/coati/dataset/__init__.py | 9 +-
.../Chat/coati/dataset/conversation.py | 18 +-
.../Chat/coati/dataset/prompt_dataset.py | 20 +-
.../Chat/coati/dataset/reward_dataset.py | 102 +-
.../Chat/coati/dataset/sft_dataset.py | 127 +-
.../Chat/coati/experience_buffer/__init__.py | 2 +-
.../Chat/coati/experience_buffer/base.py | 6 +-
.../Chat/coati/experience_buffer/naive.py | 12 +-
.../Chat/coati/experience_buffer/utils.py | 15 +-
.../Chat/coati/experience_maker/__init__.py | 2 +-
.../Chat/coati/experience_maker/base.py | 11 +-
.../Chat/coati/experience_maker/naive.py | 11 +-
applications/Chat/coati/kernels/__init__.py | 4 +-
applications/Chat/coati/kernels/opt_attn.py | 23 +-
applications/Chat/coati/models/__init__.py | 11 +-
.../Chat/coati/models/base/__init__.py | 9 +-
applications/Chat/coati/models/base/actor.py | 13 +-
applications/Chat/coati/models/base/critic.py | 15 +-
.../Chat/coati/models/base/reward_model.py | 16 +-
.../Chat/coati/models/bloom/__init__.py | 2 +-
.../Chat/coati/models/bloom/bloom_actor.py | 17 +-
.../Chat/coati/models/bloom/bloom_critic.py | 17 +-
.../Chat/coati/models/bloom/bloom_rm.py | 14 +-
.../Chat/coati/models/chatglm/__init__.py | 2 +-
.../coati/models/chatglm/chatglm_actor.py | 13 +-
.../coati/models/chatglm/chatglm_tokenizer.py | 118 +-
.../models/chatglm/configuration_chatglm.py | 48 +-
.../coati/models/chatglm/modeling_chatglm.py | 406 +-
applications/Chat/coati/models/generation.py | 101 +-
.../Chat/coati/models/gpt/__init__.py | 2 +-
.../Chat/coati/models/gpt/gpt_actor.py | 16 +-
.../Chat/coati/models/gpt/gpt_critic.py | 14 +-
applications/Chat/coati/models/gpt/gpt_rm.py | 12 +-
.../Chat/coati/models/llama/__init__.py | 2 +-
.../Chat/coati/models/llama/llama_actor.py | 18 +-
.../Chat/coati/models/llama/llama_critic.py | 15 +-
.../Chat/coati/models/llama/llama_rm.py | 15 +-
applications/Chat/coati/models/lora.py | 32 +-
applications/Chat/coati/models/loss.py | 28 +-
.../Chat/coati/models/opt/__init__.py | 2 +-
.../Chat/coati/models/opt/opt_actor.py | 14 +-
.../Chat/coati/models/opt/opt_critic.py | 14 +-
applications/Chat/coati/models/opt/opt_rm.py | 12 +-
applications/Chat/coati/models/utils.py | 20 +-
applications/Chat/coati/quant/__init__.py | 4 +-
.../Chat/coati/quant/llama_gptq/__init__.py | 2 +-
.../Chat/coati/quant/llama_gptq/loader.py | 5 +-
.../coati/quant/llama_gptq/model_utils.py | 5 +-
.../Chat/coati/quant/llama_gptq/quant.py | 36 +-
applications/Chat/coati/quant/utils.py | 3 +-
applications/Chat/coati/ray/callbacks/base.py | 3 +-
.../ray/callbacks/performance_evaluator.py | 58 +-
.../Chat/coati/ray/detached_replay_buffer.py | 25 +-
.../Chat/coati/ray/detached_trainer_base.py | 38 +-
.../Chat/coati/ray/detached_trainer_ppo.py | 85 +-
.../Chat/coati/ray/experience_maker_holder.py | 113 +-
.../Chat/coati/ray/lora_constructor.py | 53 +-
applications/Chat/coati/ray/utils.py | 80 +-
applications/Chat/coati/trainer/__init__.py | 6 +-
applications/Chat/coati/trainer/base.py | 14 +-
.../Chat/coati/trainer/callbacks/__init__.py | 2 +-
.../Chat/coati/trainer/callbacks/base.py | 2 +-
.../callbacks/performance_evaluator.py | 48 +-
.../trainer/callbacks/save_checkpoint.py | 29 +-
applications/Chat/coati/trainer/ppo.py | 101 +-
applications/Chat/coati/trainer/rm.py | 18 +-
applications/Chat/coati/trainer/sft.py | 51 +-
.../Chat/coati/trainer/strategies/__init__.py | 5 +-
.../Chat/coati/trainer/strategies/base.py | 24 +-
.../coati/trainer/strategies/colossalai.py | 113 +-
.../Chat/coati/trainer/strategies/ddp.py | 62 +-
.../Chat/coati/trainer/strategies/sampler.py | 7 +-
applications/Chat/coati/trainer/utils.py | 1 -
.../Chat/evaluate/config/config_cn.json | 24 +-
applications/Chat/evaluate/eval.py | 88 +-
applications/Chat/evaluate/evaluator.py | 46 +-
applications/Chat/evaluate/gpt_evaluate.py | 136 +-
applications/Chat/evaluate/metrics.py | 39 +-
.../Chat/evaluate/unieval/__init__.py | 7 +-
.../Chat/evaluate/unieval/evaluator.py | 234 +-
applications/Chat/evaluate/unieval/scorer.py | 47 +-
applications/Chat/evaluate/unieval/utils.py | 165 +-
applications/Chat/evaluate/utils.py | 9 +-
.../examples/community/peft/easy_dataset.py | 108 +-
.../examples/community/peft/easy_models.py | 39 +-
.../community/peft/train_peft_prompts.py | 156 +-
.../examples/community/peft/train_peft_sft.py | 171 +-
.../examples/community/ray/ray_job_script.py | 25 +-
.../community/ray/train_prompts_on_ray.py | 250 +-
applications/Chat/examples/download_model.py | 19 +-
.../examples/generate_conversation_dataset.py | 33 +-
.../Chat/examples/generate_prompt_dataset.py | 18 +-
applications/Chat/examples/inference.py | 51 +-
applications/Chat/examples/ray/1mmt_prompt.py | 102 +-
applications/Chat/examples/ray/mmmt_prompt.py | 116 +-
applications/Chat/examples/requirements.txt | 2 +-
applications/Chat/examples/train_prompts.py | 174 +-
.../Chat/examples/train_reward_model.py | 205 +-
applications/Chat/examples/train_sft.py | 233 +-
applications/Chat/inference/benchmark.py | 48 +-
applications/Chat/inference/locustfile.py | 30 +-
applications/Chat/inference/server.py | 99 +-
.../Chat/inference/tests/test_chat_prompt.py | 70 +-
applications/Chat/inference/utils.py | 91 +-
applications/Chat/requirements-test.txt | 2 +-
applications/Chat/setup.py | 42 +-
applications/Chat/tests/test_checkpoint.py | 29 +-
applications/Chat/tests/test_dataset.py | 115 +-
applications/Chat/tests/test_experience.py | 44 +-
applications/Chat/tests/test_models.py | 153 +-
colossalai/__init__.py | 6 +-
.../_subclasses/_meta_registration.py | 161 +-
.../_analyzer/_subclasses/_monkey_patch.py | 3 +-
.../_analyzer/_subclasses/flop_tensor.py | 77 +-
.../_analyzer/_subclasses/meta_tensor.py | 46 +-
colossalai/_analyzer/fx/codegen.py | 181 +-
colossalai/_analyzer/fx/graph_module.py | 54 +-
colossalai/_analyzer/fx/node_util.py | 54 +-
.../_analyzer/fx/passes/graph_profile.py | 104 +-
colossalai/_analyzer/fx/passes/shape_prop.py | 36 +-
colossalai/_analyzer/fx/symbolic_profile.py | 4 -
.../_analyzer/fx/tracer/bias_addition.py | 190 +-
.../_analyzer/fx/tracer/custom_leaf_module.py | 1 +
colossalai/_analyzer/fx/tracer/proxy.py | 15 +-
.../_analyzer/fx/tracer/symbolic_trace.py | 12 +-
colossalai/_analyzer/fx/tracer/tracer.py | 106 +-
.../amp/naive_amp/grad_scaler/__init__.py | 2 +-
.../naive_amp/grad_scaler/base_grad_scaler.py | 17 +-
.../grad_scaler/constant_grad_scaler.py | 3 +-
.../grad_scaler/dynamic_grad_scaler.py | 64 +-
.../mixed_precision_mixin/__init__.py | 6 +-
.../naive_amp/mixed_precision_mixin/base.py | 9 +-
.../naive_amp/mixed_precision_mixin/fp16.py | 37 +-
.../naive_amp/mixed_precision_optimizer.py | 101 +-
.../auto_parallel/checkpoint/build_c_ext.py | 16 +-
.../checkpoint/ckpt_solver_base.py | 35 +-
.../checkpoint/ckpt_solver_chen.py | 7 +-
.../checkpoint/ckpt_solver_rotor.py | 118 +-
.../auto_parallel/checkpoint/operation.py | 39 +-
.../auto_parallel/meta_profiler/constants.py | 2 -
.../meta_profiler/meta_registry/activation.py | 43 +-
.../meta_registry/binary_elementwise_ops.py | 6 +-
.../meta_profiler/meta_registry/conv.py | 62 +-
.../meta_profiler/meta_registry/embedding.py | 12 +-
.../meta_profiler/meta_registry/linear.py | 307 +-
.../meta_profiler/meta_registry/non_spmd.py | 2 +-
.../meta_profiler/meta_registry/norm.py | 102 +-
.../meta_profiler/meta_registry/pooling.py | 14 +-
.../meta_profiler/meta_registry/tensor.py | 43 +-
.../meta_profiler/meta_registry/where.py | 27 +-
.../auto_parallel/meta_profiler/registry.py | 8 +-
.../meta_profiler/shard_metainfo.py | 32 +-
.../auto_parallel/offload/amp_optimizer.py | 60 +-
.../offload/base_offload_module.py | 12 +-
.../auto_parallel/offload/mem_optimize.py | 14 +-
colossalai/auto_parallel/offload/region.py | 10 +-
.../auto_parallel/offload/region_manager.py | 137 +-
colossalai/auto_parallel/offload/runtime.py | 68 +-
colossalai/auto_parallel/offload/solver.py | 105 +-
.../offload/training_simulator.py | 130 +-
colossalai/auto_parallel/offload/util.py | 22 +-
.../passes/comm_metainfo_pass.py | 57 +-
.../auto_parallel/passes/meta_info_prop.py | 22 +-
.../passes/runtime_apply_pass.py | 129 +-
.../passes/runtime_preparation_pass.py | 167 +-
.../auto_parallel/tensor_shard/constants.py | 56 +-
.../auto_parallel/tensor_shard/initialize.py | 256 +-
.../tensor_shard/node_handler/__init__.py | 36 +-
.../node_handler/addmm_handler.py | 46 +-
.../node_handler/batch_norm_handler.py | 59 +-
.../binary_elementwise_handler.py | 48 +-
.../tensor_shard/node_handler/bmm_handler.py | 44 +-
.../tensor_shard/node_handler/conv_handler.py | 68 +-
.../node_handler/default_reshape_handler.py | 18 +-
.../node_handler/embedding_handler.py | 104 +-
.../node_handler/getattr_handler.py | 2 +-
.../node_handler/getitem_handler.py | 8 +-
.../node_handler/layer_norm_handler.py | 30 +-
.../node_handler/linear_handler.py | 175 +-
.../node_handler/matmul_handler.py | 134 +-
.../tensor_shard/node_handler/node_handler.py | 89 +-
.../node_handler/normal_pooling_handler.py | 10 +-
.../node_handler/output_handler.py | 11 +-
.../node_handler/permute_handler.py | 14 +-
.../node_handler/placeholder_handler.py | 10 +-
.../tensor_shard/node_handler/registry.py | 6 +-
.../node_handler/softmax_handler.py | 8 +-
.../node_handler/split_handler.py | 8 +-
.../node_handler/strategy/__init__.py | 34 +-
.../strategy/batch_norm_generator.py | 176 +-
.../strategy/binary_elementwise_generator.py | 40 +-
.../strategy/conv_strategy_generator.py | 231 +-
.../strategy/embedding_generator.py | 116 +-
.../strategy/getattr_generator.py | 13 +-
.../strategy/getitem_generator.py | 54 +-
.../strategy/layer_norm_generator.py | 57 +-
.../strategy/matmul_strategy_generator.py | 688 +-
.../strategy/normal_pooling_generator.py | 35 +-
.../node_handler/strategy/output_generator.py | 51 +-
.../strategy/placeholder_generator.py | 43 +-
.../strategy/reshape_generator.py | 97 +-
.../strategy/softmax_generator.py | 51 +-
.../strategy/strategy_generator.py | 106 +-
.../node_handler/strategy/sum_generator.py | 54 +-
.../strategy/tensor_constructor_generator.py | 29 +-
.../strategy/unary_elementwise_generator.py | 23 +-
.../node_handler/strategy/where_generator.py | 28 +-
.../tensor_shard/node_handler/sum_handler.py | 8 +-
.../tensor_constructor_handler.py | 2 +-
.../node_handler/transpose_handler.py | 10 +-
.../node_handler/unary_elementwise_handler.py | 10 +-
.../tensor_shard/node_handler/view_handler.py | 6 +-
.../node_handler/where_handler.py | 33 +-
.../auto_parallel/tensor_shard/options.py | 6 +-
.../tensor_shard/sharding_strategy.py | 46 +-
.../tensor_shard/solver/__init__.py | 2 +-
.../tensor_shard/solver/cost_graph.py | 18 +-
.../tensor_shard/solver/graph_analysis.py | 25 +-
.../tensor_shard/solver/solver.py | 128 +-
.../solver/strategies_constructor.py | 118 +-
.../tensor_shard/utils/__init__.py | 22 +-
.../tensor_shard/utils/broadcast.py | 42 +-
.../tensor_shard/utils/factory.py | 57 +-
.../auto_parallel/tensor_shard/utils/misc.py | 27 +-
.../tensor_shard/utils/reshape.py | 16 +-
.../tensor_shard/utils/sharding.py | 23 +-
colossalai/autochunk/autochunk_codegen.py | 162 +-
colossalai/autochunk/estimate_memory.py | 24 +-
colossalai/autochunk/search_chunk.py | 39 +-
colossalai/autochunk/select_chunk.py | 70 +-
colossalai/autochunk/trace_flow.py | 36 +-
colossalai/autochunk/trace_indice.py | 30 +-
colossalai/autochunk/utils.py | 24 +-
colossalai/booster/accelerator.py | 18 +-
colossalai/booster/booster.py | 103 +-
.../booster/mixed_precision/__init__.py | 22 +-
.../booster/mixed_precision/fp16_apex.py | 26 +-
.../booster/mixed_precision/fp16_naive.py | 18 +-
.../booster/mixed_precision/fp16_torch.py | 77 +-
colossalai/booster/plugin/__init__.py | 7 +-
colossalai/booster/plugin/dp_plugin_base.py | 38 +-
colossalai/booster/plugin/gemini_plugin.py | 138 +-
.../booster/plugin/hybrid_parallel_plugin.py | 428 +-
.../booster/plugin/low_level_zero_plugin.py | 109 +-
colossalai/booster/plugin/plugin_base.py | 29 +-
colossalai/booster/plugin/pp_plugin_base.py | 17 +-
colossalai/booster/plugin/torch_ddp_plugin.py | 75 +-
.../booster/plugin/torch_fsdp_plugin.py | 69 +-
colossalai/checkpoint_io/__init__.py | 2 +-
.../checkpoint_io/checkpoint_io_base.py | 69 +-
.../checkpoint_io/general_checkpoint_io.py | 100 +-
.../hybrid_parallel_checkpoint_io.py | 297 +-
colossalai/checkpoint_io/index_file.py | 6 +-
colossalai/checkpoint_io/utils.py | 165 +-
colossalai/cli/__init__.py | 2 +-
colossalai/cli/check/__init__.py | 5 +-
colossalai/cli/check/check_installation.py | 29 +-
colossalai/cli/cli.py | 5 +-
colossalai/cli/launcher/__init__.py | 99 +-
colossalai/cli/launcher/hostinfo.py | 5 +-
colossalai/cli/launcher/multinode_runner.py | 22 +-
colossalai/cli/launcher/run.py | 70 +-
colossalai/cluster/__init__.py | 2 +-
colossalai/cluster/device_mesh_manager.py | 21 +-
colossalai/cluster/dist_coordinator.py | 12 +-
colossalai/cluster/process_group_manager.py | 8 +-
colossalai/cluster/process_group_mesh.py | 23 +-
colossalai/context/__init__.py | 4 +-
colossalai/context/config.py | 11 +-
colossalai/context/moe_context.py | 26 +-
colossalai/context/singleton_meta.py | 5 +-
colossalai/device/__init__.py | 2 +-
colossalai/device/alpha_beta_profiler.py | 71 +-
colossalai/device/calc_pipeline_strategy.py | 58 +-
colossalai/device/device_mesh.py | 113 +-
colossalai/fx/_compatibility.py | 8 +-
colossalai/fx/_meta_regist_12.py | 144 +-
.../codegen/activation_checkpoint_codegen.py | 368 +-
colossalai/fx/graph_module.py | 56 +-
.../fx/passes/adding_split_node_pass.py | 81 +-
colossalai/fx/passes/concrete_info_prop.py | 82 +-
.../adding_shape_consistency_pass.py | 63 +-
colossalai/fx/passes/meta_info_prop.py | 97 +-
colossalai/fx/passes/passes_for_gpt2_test.py | 83 +-
colossalai/fx/passes/shard_1d_pass.py | 47 +-
colossalai/fx/passes/split_module.py | 67 +-
colossalai/fx/passes/utils.py | 29 +-
colossalai/fx/profiler/__init__.py | 11 +-
colossalai/fx/profiler/constants.py | 2 +-
colossalai/fx/profiler/dataflow.py | 19 +-
.../fx/profiler/experimental/constants.py | 34 +-
.../fx/profiler/experimental/profiler.py | 27 +-
.../profiler_function/activation_function.py | 2 +
.../profiler_function/arithmetic.py | 36 +-
.../profiler_function/embedding.py | 4 +-
.../experimental/profiler_function/linear.py | 2 +
.../profiler_function/normalization.py | 14 +-
.../experimental/profiler_function/pooling.py | 4 +-
.../profiler_function/python_ops.py | 2 +-
.../profiler_function/torch_ops.py | 14 +-
.../profiler_module/activation_function.py | 2 +
.../experimental/profiler_module/attention.py | 40 +-
.../profiler_module/convolution.py | 86 +-
.../experimental/profiler_module/dropout.py | 2 +
.../experimental/profiler_module/linear.py | 2 +
.../profiler_module/normalization.py | 9 +-
.../experimental/profiler_module/pooling.py | 2 +
.../experimental/profiler_module/rnn.py | 27 +-
.../experimental/profiler_module/torch_op.py | 5 +-
.../fx/profiler/experimental/registry.py | 6 +-
.../fx/profiler/experimental/shard_utils.py | 8 +-
colossalai/fx/profiler/memory_utils.py | 5 +-
colossalai/fx/profiler/opcount.py | 37 +-
colossalai/fx/profiler/profiler.py | 52 +-
colossalai/fx/profiler/shard_utils.py | 4 +-
colossalai/fx/profiler/tensor.py | 31 +-
colossalai/fx/proxy.py | 16 +-
colossalai/fx/tracer/_meta_trace.py | 54 +-
colossalai/fx/tracer/_tracer_utils.py | 7 +-
.../patched_bias_addition_function/addbmm.py | 28 +-
.../patched_bias_addition_function/addmm.py | 26 +-
.../bias_addition_function.py | 8 +-
.../patched_bias_addition_function/linear.py | 12 +-
.../bias_addition_module.py | 14 +-
.../patched_bias_addition_module/conv.py | 21 +-
.../patched_bias_addition_module/linear.py | 2 -
colossalai/fx/tracer/experimental.py | 229 +-
.../patched_function/activation_function.py | 2 +-
.../meta_patch/patched_function/arithmetic.py | 12 +-
.../patched_function/convolution.py | 59 +-
.../meta_patch/patched_function/embedding.py | 10 +-
.../patched_function/normalization.py | 15 +-
.../meta_patch/patched_function/python_ops.py | 4 +-
.../meta_patch/patched_function/torch_ops.py | 25 +-
.../meta_patch/patched_module/__init__.py | 2 +-
.../patched_module/activation_function.py | 2 +-
.../meta_patch/patched_module/convolution.py | 96 +-
.../meta_patch/patched_module/embedding.py | 2 +-
.../meta_patch/patched_module/linear.py | 4 +-
.../patched_module/normalization.py | 1 +
.../meta_patch/patched_module/pooling.py | 30 +-
.../tracer/meta_patch/patched_module/rnn.py | 12 +-
colossalai/fx/tracer/registry.py | 12 +-
colossalai/fx/tracer/tracer.py | 71 +-
.../inference/tensor_parallel/__init__.py | 2 +-
.../tensor_parallel/batch_infer_state.py | 15 +-
.../inference/tensor_parallel/engine.py | 77 +-
.../tensor_parallel/kvcache_manager.py | 51 +-
.../tensor_parallel/modeling/__init__.py | 2 +-
.../tensor_parallel/modeling/bloom.py | 143 +-
.../tensor_parallel/modeling/llama.py | 106 +-
.../tensor_parallel/policies/__init__.py | 2 +-
.../tensor_parallel/policies/bloom.py | 33 +-
.../tensor_parallel/policies/llama.py | 28 +-
colossalai/initialize.py | 145 +-
colossalai/interface/__init__.py | 2 +-
colossalai/interface/model.py | 4 +-
colossalai/interface/optimizer.py | 22 +-
colossalai/kernel/cuda_native/__init__.py | 8 +-
colossalai/kernel/cuda_native/csrc/compat.h | 2 +-
.../cuda_native/csrc/kernels/cuda_util.cu | 1 -
.../csrc/kernels/dropout_kernels.cu | 2004 ++---
.../csrc/kernels/general_kernels.cu | 464 +-
.../csrc/kernels/include/dropout.h | 192 +-
.../csrc/kernels/include/kernels.h | 27 +-
.../csrc/kernels/include/normalize_layer.h | 129 +-
.../csrc/kernels/include/softmax.h | 84 +-
.../csrc/kernels/normalize_kernels.cu | 2341 +++---
.../csrc/kernels/softmax_kernels.cu | 730 +-
.../csrc/kernels/transform_kernels.cu | 626 +-
.../cuda_native/csrc/layer_norm_cuda.cpp | 2 +-
.../csrc/layer_norm_cuda_kernel.cu | 2 +-
.../kernel/cuda_native/csrc/moe_cuda.cpp | 194 +-
.../cuda_native/csrc/moe_cuda_kernel.cu | 1318 ++--
.../csrc/multi_tensor_l2norm_kernel.cu | 2 +-
.../cuda_native/csrc/multi_tensor_lamb.cu | 2 +-
.../csrc/multi_tensor_scale_kernel.cu | 2 +-
.../csrc/multi_tensor_sgd_kernel.cu | 2 +-
.../csrc/scaled_masked_softmax.cpp | 84 +-
.../cuda_native/csrc/scaled_masked_softmax.h | 868 +-
.../csrc/scaled_upper_triang_masked_softmax.h | 928 ++-
colossalai/kernel/cuda_native/layer_norm.py | 15 +-
colossalai/kernel/cuda_native/mha/__init__.py | 2 +-
.../kernel/cuda_native/mha/flash_attn_2.py | 46 +-
.../kernel/cuda_native/mha/mem_eff_attn.py | 35 +-
colossalai/kernel/cuda_native/mha/mha.py | 70 +-
colossalai/kernel/cuda_native/mha/utils.py | 8 +-
.../kernel/cuda_native/multihead_attention.py | 169 +-
.../kernel/cuda_native/scaled_softmax.py | 22 +-
colossalai/kernel/jit/__init__.py | 10 +-
colossalai/kernel/jit/bias_dropout_add.py | 15 +-
colossalai/kernel/jit/bias_gelu.py | 1 -
colossalai/kernel/jit/option.py | 21 +-
colossalai/kernel/triton/__init__.py | 11 +-
colossalai/kernel/triton/context_attention.py | 146 +-
.../kernel/triton/copy_kv_cache_dest.py | 36 +-
colossalai/kernel/triton/fused_layernorm.py | 35 +-
colossalai/kernel/triton/qkv_matmul_kernel.py | 54 +-
colossalai/kernel/triton/rms_norm.py | 21 +-
.../kernel/triton/rotary_embedding_kernel.py | 46 +-
.../kernel/triton/self_attention_nofusion.py | 29 +-
colossalai/kernel/triton/softmax.py | 65 +-
.../kernel/triton/token_attention_kernel.py | 211 +-
colossalai/lazy/__init__.py | 4 +-
colossalai/lazy/lazy_init.py | 189 +-
colossalai/legacy/__init__.py | 10 +-
colossalai/legacy/amp/__init__.py | 5 +-
colossalai/legacy/amp/amp_type.py | 6 +-
colossalai/legacy/amp/apex_amp/__init__.py | 3 +-
colossalai/legacy/amp/apex_amp/apex_amp.py | 2 +-
colossalai/legacy/amp/naive_amp/__init__.py | 4 +-
.../legacy/amp/naive_amp/_fp16_optimizer.py | 75 +-
colossalai/legacy/amp/naive_amp/_utils.py | 2 +-
colossalai/legacy/amp/naive_amp/naive_amp.py | 20 +-
colossalai/legacy/amp/torch_amp/__init__.py | 9 +-
.../legacy/amp/torch_amp/_grad_scaler.py | 86 +-
colossalai/legacy/amp/torch_amp/torch_amp.py | 3 +-
colossalai/legacy/builder/__init__.py | 2 +-
colossalai/legacy/builder/builder.py | 16 +-
colossalai/legacy/communication/__init__.py | 34 +-
colossalai/legacy/communication/collective.py | 32 +-
colossalai/legacy/communication/p2p.py | 244 +-
colossalai/legacy/communication/p2p_v2.py | 16 +-
colossalai/legacy/communication/ring.py | 17 +-
colossalai/legacy/communication/utils.py | 6 +-
colossalai/legacy/constants.py | 40 +-
colossalai/legacy/context/parallel_context.py | 98 +-
colossalai/legacy/context/parallel_mode.py | 37 +-
.../process_group_initializer/__init__.py | 12 +-
.../initializer_1d.py | 2 +-
.../initializer_2d.py | 16 +-
.../initializer_2p5d.py | 46 +-
.../initializer_3d.py | 24 +-
.../initializer_data.py | 2 +-
.../initializer_model.py | 2 +-
.../initializer_pipeline.py | 18 +-
.../initializer_sequence.py | 12 +-
.../initializer_tensor.py | 2 +-
.../process_group_initializer.py | 11 +-
colossalai/legacy/context/random/__init__.py | 13 +-
colossalai/legacy/context/random/_helper.py | 3 +-
.../legacy/context/random/seed_manager.py | 6 +-
colossalai/legacy/core.py | 2 +-
colossalai/legacy/engine/__init__.py | 2 +-
colossalai/legacy/engine/_base_engine.py | 42 +-
.../engine/gradient_accumulation/__init__.py | 21 +-
.../_gradient_accumulation.py | 5 +-
.../engine/gradient_handler/__init__.py | 8 +-
.../_base_gradient_handler.py | 1 -
.../_data_parallel_gradient_handler.py | 3 +-
.../gradient_handler/_moe_gradient_handler.py | 5 +-
.../_pipeline_parallel_gradient_handler.py | 18 +-
.../_sequence_parallel_gradient_handler.py | 3 +-
.../_zero_gradient_handler.py | 3 +-
colossalai/legacy/engine/schedule/__init__.py | 2 +-
.../legacy/engine/schedule/_base_schedule.py | 39 +-
.../engine/schedule/_non_pipeline_schedule.py | 28 +-
.../engine/schedule/_pipeline_schedule.py | 316 +-
.../engine/schedule/_pipeline_schedule_v2.py | 40 +-
colossalai/legacy/global_variables.py | 54 +-
colossalai/legacy/initialize.py | 338 +-
colossalai/legacy/nn/_ops/_utils.py | 33 +-
colossalai/legacy/nn/layer/base_layer.py | 56 +-
.../nn/layer/colossalai_layer/__init__.py | 2 +-
.../nn/layer/colossalai_layer/_utils.py | 5 +-
.../nn/layer/colossalai_layer/dropout.py | 2 +-
.../nn/layer/colossalai_layer/embedding.py | 53 +-
.../nn/layer/colossalai_layer/linear.py | 63 +-
.../legacy/nn/layer/parallel_1d/__init__.py | 12 +-
.../legacy/nn/layer/parallel_1d/_operation.py | 15 +-
.../legacy/nn/layer/parallel_1d/_utils.py | 7 +-
.../legacy/nn/layer/parallel_1d/layers.py | 551 +-
.../legacy/nn/layer/parallel_2d/__init__.py | 11 +-
.../legacy/nn/layer/parallel_2d/_operation.py | 303 +-
.../legacy/nn/layer/parallel_2d/_utils.py | 16 +-
.../legacy/nn/layer/parallel_2d/layers.py | 615 +-
.../legacy/nn/layer/parallel_2p5d/__init__.py | 11 +-
.../nn/layer/parallel_2p5d/_operation.py | 510 +-
.../legacy/nn/layer/parallel_2p5d/_utils.py | 27 +-
.../legacy/nn/layer/parallel_2p5d/layers.py | 562 +-
.../legacy/nn/layer/parallel_3d/__init__.py | 12 +-
.../legacy/nn/layer/parallel_3d/_operation.py | 65 +-
.../legacy/nn/layer/parallel_3d/_utils.py | 23 +-
.../legacy/nn/layer/parallel_3d/layers.py | 509 +-
.../nn/layer/parallel_sequence/__init__.py | 2 +-
.../nn/layer/parallel_sequence/_operation.py | 31 +-
.../nn/layer/parallel_sequence/layers.py | 119 +-
colossalai/legacy/nn/layer/utils/__init__.py | 10 +-
colossalai/legacy/nn/layer/utils/common.py | 9 +-
.../legacy/nn/layer/vanilla/__init__.py | 9 +-
colossalai/legacy/nn/layer/vanilla/layers.py | 93 +-
.../legacy/nn/layer/wrapper/__init__.py | 2 +-
.../nn/layer/wrapper/pipeline_wrapper.py | 17 +-
colossalai/legacy/nn/loss/__init__.py | 19 +-
colossalai/legacy/nn/loss/loss_1d.py | 4 +-
colossalai/legacy/nn/loss/loss_2d.py | 12 +-
colossalai/legacy/nn/loss/loss_2p5d.py | 12 +-
colossalai/legacy/nn/loss/loss_3d.py | 6 +-
colossalai/legacy/nn/metric/__init__.py | 7 +-
colossalai/legacy/nn/metric/accuracy_2d.py | 3 +-
colossalai/legacy/nn/metric/accuracy_2p5d.py | 3 +-
colossalai/legacy/nn/metric/accuracy_3d.py | 5 +-
colossalai/legacy/nn/parallel/__init__.py | 2 +-
.../legacy/nn/parallel/data_parallel.py | 37 +-
.../legacy/nn/parallel/layers/__init__.py | 20 +-
.../layers/cache_embedding/__init__.py | 11 +-
.../layers/cache_embedding/base_embedding.py | 9 +-
.../layers/cache_embedding/cache_mgr.py | 144 +-
.../cache_embedding/cached_embedding.py | 116 +-
.../parallel/layers/cache_embedding/copyer.py | 2 +-
.../cache_embedding/embedding_config.py | 22 +-
.../parallel_cached_embedding.py | 147 +-
.../parallel_cached_embedding_tablewise.py | 124 +-
..._cached_embedding_tablewise_split_cache.py | 91 +-
.../legacy/nn/parallel/layers/colo_module.py | 19 +-
.../legacy/nn/parallel/layers/embedding.py | 15 +-
.../legacy/nn/parallel/layers/linear.py | 21 +-
.../legacy/nn/parallel/layers/module_utils.py | 20 +-
colossalai/legacy/nn/parallel/reducer.py | 12 +-
colossalai/legacy/pipeline/__init__.py | 2 +-
colossalai/legacy/pipeline/layer_spec.py | 6 +-
.../legacy/pipeline/middleware/__init__.py | 2 +-
.../pipeline/middleware/adaptor/__init__.py | 2 +-
.../legacy/pipeline/middleware/adaptor/fx.py | 21 +-
colossalai/legacy/pipeline/middleware/topo.py | 56 +-
colossalai/legacy/pipeline/pipelinable.py | 15 +-
.../legacy/pipeline/pipeline_process_group.py | 32 +-
colossalai/legacy/pipeline/rpc/__init__.py | 2 +-
.../legacy/pipeline/rpc/_pipeline_base.py | 343 +-
.../legacy/pipeline/rpc/_pipeline_schedule.py | 151 +-
colossalai/legacy/pipeline/rpc/utils.py | 50 +-
colossalai/legacy/pipeline/utils.py | 26 +-
colossalai/legacy/registry/registry.py | 2 +-
colossalai/legacy/tensor/__init__.py | 16 +-
colossalai/legacy/tensor/compute_spec.py | 2 +-
colossalai/legacy/tensor/const.py | 2 +-
colossalai/legacy/tensor/dist_spec_mgr.py | 56 +-
colossalai/legacy/tensor/distspec.py | 13 +-
colossalai/legacy/tensor/process_group.py | 55 +-
colossalai/legacy/tensor/tensor_spec.py | 3 +-
colossalai/legacy/trainer/__init__.py | 2 +-
colossalai/legacy/trainer/_trainer.py | 5 +-
colossalai/legacy/trainer/hooks/__init__.py | 15 +-
colossalai/legacy/trainer/hooks/_base_hook.py | 46 +-
.../legacy/trainer/hooks/_checkpoint_hook.py | 37 +-
colossalai/legacy/trainer/hooks/_commons_.py | 4 +-
colossalai/legacy/trainer/hooks/_log_hook.py | 126 +-
.../trainer/hooks/_lr_scheduler_hook.py | 9 +-
.../legacy/trainer/hooks/_metric_hook.py | 71 +-
colossalai/legacy/utils/__init__.py | 48 +-
.../legacy/utils/activation_checkpoint.py | 35 +-
.../legacy/utils/checkpoint/__init__.py | 2 +-
.../utils/checkpoint/module_checkpoint.py | 68 +-
colossalai/legacy/utils/checkpoint/utils.py | 13 +-
colossalai/legacy/utils/checkpointing.py | 66 +-
colossalai/legacy/utils/common.py | 77 +-
.../legacy/utils/data_sampler/__init__.py | 2 +-
.../legacy/utils/data_sampler/base_sampler.py | 1 -
.../data_sampler/data_parallel_sampler.py | 57 +-
colossalai/legacy/utils/memory.py | 22 +-
colossalai/legacy/utils/profiler/extention.py | 1 -
.../legacy/utils/profiler/legacy/__init__.py | 2 +-
.../utils/profiler/legacy/comm_profiler.py | 99 +-
.../utils/profiler/legacy/pcie_profiler.py | 39 +-
.../utils/profiler/legacy/prof_utils.py | 34 +-
colossalai/legacy/utils/profiler/profiler.py | 62 +-
.../profiler/stateful_tensor_mem_extention.py | 25 +-
colossalai/legacy/zero/__init__.py | 21 +-
colossalai/legacy/zero/gemini/__init__.py | 9 +-
.../legacy/zero/gemini/gemini_context.py | 35 +-
.../zero/gemini/ophooks/_shard_grad_ophook.py | 2 +-
.../gemini/ophooks/_shard_param_ophook.py | 8 +-
.../gemini/ophooks/runtime_mem_tracer_hook.py | 14 +-
.../legacy/zero/gemini/ophooks/utils.py | 15 +-
.../zero/gemini/paramhooks/_param_hookmgr.py | 7 +-
.../legacy/zero/gemini/stateful_tensor.py | 14 +-
.../legacy/zero/gemini/stateful_tensor_mgr.py | 28 +-
.../zero/gemini/tensor_placement_policy.py | 44 +-
colossalai/legacy/zero/gemini/tensor_utils.py | 22 +-
colossalai/legacy/zero/init_ctx/__init__.py | 2 +-
.../legacy/zero/init_ctx/init_context.py | 57 +-
.../legacy/zero/shard_utils/__init__.py | 2 +-
.../zero/shard_utils/base_shard_strategy.py | 4 +-
.../bucket_tensor_shard_strategy.py | 5 +-
.../zero/shard_utils/tensor_shard_strategy.py | 8 +-
.../legacy/zero/sharded_model/__init__.py | 2 +-
.../legacy/zero/sharded_model/_utils.py | 2 +-
.../zero/sharded_model/reduce_scatter.py | 26 +-
.../zero/sharded_model/sharded_model_v2.py | 218 +-
colossalai/legacy/zero/sharded_model/utils.py | 2 +-
.../legacy/zero/sharded_model/zero_hook.py | 25 +-
.../legacy/zero/sharded_optim/__init__.py | 2 +-
.../zero/sharded_optim/sharded_optim_v2.py | 127 +-
.../legacy/zero/sharded_param/__init__.py | 2 +-
.../zero/sharded_param/sharded_param.py | 4 +-
.../zero/sharded_param/sharded_tensor.py | 1 -
colossalai/logging/__init__.py | 8 +-
colossalai/logging/logger.py | 35 +-
colossalai/nn/init.py | 52 +-
colossalai/nn/layer/moe/__init__.py | 15 +-
colossalai/nn/layer/moe/_operation.py | 12 +-
colossalai/nn/layer/moe/checkpoint.py | 12 +-
colossalai/nn/layer/moe/experts.py | 40 +-
colossalai/nn/layer/moe/layers.py | 64 +-
colossalai/nn/layer/moe/routers.py | 461 +-
colossalai/nn/layer/moe/utils.py | 139 +-
colossalai/nn/layer/utils.py | 5 +-
colossalai/nn/lr_scheduler/__init__.py | 19 +-
colossalai/nn/lr_scheduler/cosine.py | 31 +-
colossalai/nn/lr_scheduler/delayed.py | 39 +-
colossalai/nn/lr_scheduler/linear.py | 6 +-
colossalai/nn/lr_scheduler/multistep.py | 36 +-
colossalai/nn/lr_scheduler/onecycle.py | 52 +-
colossalai/nn/lr_scheduler/poly.py | 38 +-
colossalai/nn/optimizer/README.md | 2 +-
colossalai/nn/optimizer/__init__.py | 2 +-
colossalai/nn/optimizer/cpu_adam.py | 141 +-
colossalai/nn/optimizer/fused_adam.py | 80 +-
colossalai/nn/optimizer/fused_lamb.py | 147 +-
colossalai/nn/optimizer/fused_sgd.py | 55 +-
colossalai/nn/optimizer/hybrid_adam.py | 135 +-
colossalai/nn/optimizer/lamb.py | 30 +-
colossalai/nn/optimizer/lars.py | 38 +-
colossalai/nn/optimizer/nvme_optimizer.py | 28 +-
colossalai/pipeline/__init__.py | 10 +-
colossalai/pipeline/p2p.py | 29 +-
colossalai/pipeline/schedule/__init__.py | 6 +-
colossalai/pipeline/schedule/_utils.py | 23 +-
colossalai/pipeline/schedule/base.py | 17 +-
.../pipeline/schedule/interleaved_pp.py | 56 +-
colossalai/pipeline/schedule/one_f_one_b.py | 70 +-
colossalai/pipeline/stage_manager.py | 11 +-
colossalai/shardformer/_utils.py | 22 +-
.../examples/convergence_benchmark.py | 111 +-
colossalai/shardformer/examples/data.py | 33 +-
.../examples/performance_benchmark.py | 44 +-
colossalai/shardformer/layer/__init__.py | 16 +-
colossalai/shardformer/layer/_operation.py | 66 +-
colossalai/shardformer/layer/dropout.py | 11 +-
colossalai/shardformer/layer/embedding.py | 131 +-
colossalai/shardformer/layer/linear.py | 195 +-
colossalai/shardformer/layer/loss.py | 14 +-
colossalai/shardformer/layer/normalization.py | 53 +-
.../shardformer/layer/parallel_module.py | 45 +-
.../shardformer/layer/qkv_fused_linear.py | 284 +-
colossalai/shardformer/layer/utils.py | 16 +-
colossalai/shardformer/modeling/bert.py | 420 +-
colossalai/shardformer/modeling/blip2.py | 14 +-
colossalai/shardformer/modeling/bloom.py | 280 +-
colossalai/shardformer/modeling/chatglm2.py | 181 +-
.../chatglm2_6b/configuration_chatglm.py | 54 +-
.../modeling/chatglm2_6b/modeling_chatglm.py | 196 +-
colossalai/shardformer/modeling/gpt2.py | 507 +-
colossalai/shardformer/modeling/jit.py | 3 -
colossalai/shardformer/modeling/llama.py | 117 +-
colossalai/shardformer/modeling/opt.py | 152 +-
colossalai/shardformer/modeling/sam.py | 32 +-
colossalai/shardformer/modeling/t5.py | 189 +-
colossalai/shardformer/modeling/vit.py | 93 +-
colossalai/shardformer/modeling/whisper.py | 227 +-
.../shardformer/policies/auto_policy.py | 228 +-
.../shardformer/policies/base_policy.py | 24 +-
colossalai/shardformer/policies/bert.py | 431 +-
colossalai/shardformer/policies/blip2.py | 496 +-
colossalai/shardformer/policies/bloom.py | 309 +-
colossalai/shardformer/policies/chatglm2.py | 215 +-
colossalai/shardformer/policies/gpt2.py | 214 +-
colossalai/shardformer/policies/llama.py | 145 +-
colossalai/shardformer/policies/opt.py | 210 +-
colossalai/shardformer/policies/sam.py | 236 +-
colossalai/shardformer/policies/t5.py | 426 +-
colossalai/shardformer/policies/vit.py | 185 +-
colossalai/shardformer/policies/whisper.py | 420 +-
colossalai/shardformer/shard/__init__.py | 2 +-
colossalai/shardformer/shard/shard_config.py | 5 +-
colossalai/shardformer/shard/sharder.py | 63 +-
colossalai/tensor/__init__.py | 13 +-
colossalai/tensor/colo_parameter.py | 12 +-
colossalai/tensor/colo_tensor.py | 19 +-
colossalai/tensor/comm_spec.py | 90 +-
colossalai/tensor/d_tensor/__init__.py | 23 +-
colossalai/tensor/d_tensor/api.py | 62 +-
colossalai/tensor/d_tensor/comm_spec.py | 68 +-
colossalai/tensor/d_tensor/layout.py | 12 +-
.../tensor/d_tensor/layout_converter.py | 97 +-
colossalai/tensor/d_tensor/sharding_spec.py | 96 +-
colossalai/tensor/d_tensor/utils.py | 4 +-
colossalai/tensor/param_op_hook.py | 3 +-
colossalai/tensor/shape_consistency.py | 147 +-
colossalai/tensor/sharding_spec.py | 103 +-
colossalai/tensor/utils.py | 35 +-
colossalai/testing/__init__.py | 18 +-
colossalai/testing/comparison.py | 52 +-
colossalai/testing/pytest_wrapper.py | 9 +-
colossalai/testing/random.py | 2 +-
colossalai/testing/utils.py | 15 +-
colossalai/utils/__init__.py | 32 +-
colossalai/utils/common.py | 2 +-
colossalai/utils/cuda.py | 4 +-
colossalai/utils/model/utils.py | 19 +-
colossalai/utils/moe.py | 5 +-
.../multi_tensor_apply/multi_tensor_apply.py | 4 +-
colossalai/utils/rank_recorder/README.md | 8 +-
colossalai/utils/rank_recorder/__init__.py | 2 +-
.../utils/rank_recorder/rank_recorder.py | 59 +-
colossalai/utils/tensor_detector/__init__.py | 2 +-
colossalai/utils/tensor_detector/readme.md | 3 +-
.../utils/tensor_detector/tensor_detector.py | 85 +-
colossalai/utils/timer.py | 15 +-
colossalai/zero/__init__.py | 11 +-
colossalai/zero/gemini/__init__.py | 13 +-
colossalai/zero/gemini/chunk/__init__.py | 2 +-
colossalai/zero/gemini/chunk/chunk.py | 157 +-
colossalai/zero/gemini/chunk/manager.py | 58 +-
colossalai/zero/gemini/chunk/search_utils.py | 24 +-
colossalai/zero/gemini/chunk/utils.py | 28 +-
colossalai/zero/gemini/colo_init_context.py | 69 +-
colossalai/zero/gemini/gemini_ddp.py | 283 +-
colossalai/zero/gemini/gemini_hook.py | 7 +-
colossalai/zero/gemini/gemini_mgr.py | 46 +-
colossalai/zero/gemini/gemini_optimizer.py | 281 +-
.../zero/gemini/memory_tracer/__init__.py | 18 +-
.../memory_tracer/chunk_memstats_collector.py | 4 +-
.../gemini/memory_tracer/memory_monitor.py | 1 +
.../zero/gemini/memory_tracer/memory_stats.py | 11 +-
.../memory_tracer/memstats_collector.py | 15 +-
.../memory_tracer/param_runtime_order.py | 1 -
.../memory_tracer/runtime_mem_tracer.py | 4 +-
.../static_memstats_collector.py | 22 +-
colossalai/zero/gemini/memory_tracer/utils.py | 8 +-
colossalai/zero/gemini/placement_policy.py | 93 +-
colossalai/zero/gemini/utils.py | 27 +-
colossalai/zero/low_level/__init__.py | 2 +-
colossalai/zero/low_level/_utils.py | 27 +-
.../zero/low_level/bookkeeping/__init__.py | 2 +-
.../zero/low_level/bookkeeping/base_store.py | 1 -
.../low_level/bookkeeping/bucket_store.py | 9 +-
.../low_level/bookkeeping/gradient_store.py | 2 -
.../low_level/bookkeeping/parameter_store.py | 1 -
.../low_level/bookkeeping/tensor_bucket.py | 4 +-
colossalai/zero/low_level/low_level_optim.py | 186 +-
colossalai/zero/wrapper.py | 57 +-
examples/community/fp8/mnist/main.py | 37 +-
.../roberta/preprocessing/get_mask.py | 72 +-
.../roberta/preprocessing/sentence_split.py | 59 +-
.../roberta/preprocessing/tokenize_mask.py | 100 +-
.../roberta/pretraining/arguments.py | 93 +-
.../pretraining/bert_dataset_provider.py | 1 -
.../roberta/pretraining/evaluation.py | 42 +-
.../community/roberta/pretraining/loss.py | 5 +-
.../roberta/pretraining/model/bert.py | 135 +-
.../roberta/pretraining/model/deberta_v2.py | 148 +-
.../nvidia_bert_dataset_provider.py | 57 +-
.../roberta/pretraining/pretrain_utils.py | 70 +-
.../roberta/pretraining/run_pretraining.py | 159 +-
.../roberta/pretraining/utils/WandbLog.py | 8 +-
.../roberta/pretraining/utils/exp_util.py | 51 +-
.../roberta/pretraining/utils/global_vars.py | 22 +-
.../roberta/pretraining/utils/logger.py | 14 +-
examples/images/diffusion/README.md | 2 +-
.../images/diffusion/configs/train_ddp.yaml | 2 +-
examples/images/diffusion/ldm/data/base.py | 27 +-
examples/images/diffusion/ldm/data/cifar10.py | 75 +-
.../images/diffusion/ldm/data/imagenet.py | 123 +-
examples/images/diffusion/ldm/data/lsun.py | 116 +-
examples/images/diffusion/ldm/data/teyvat.py | 61 +-
examples/images/diffusion/ldm/lr_scheduler.py | 31 +-
.../diffusion/ldm/models/autoencoder.py | 106 +-
.../ldm/models/diffusion/classifier.py | 145 +-
.../diffusion/ldm/models/diffusion/ddim.py | 344 +-
.../diffusion/ldm/models/diffusion/ddpm.py | 1099 +--
.../models/diffusion/dpm_solver/__init__.py | 2 +-
.../models/diffusion/dpm_solver/dpm_solver.py | 564 +-
.../models/diffusion/dpm_solver/sampler.py | 69 +-
.../diffusion/ldm/models/diffusion/plms.py | 258 +-
.../ldm/models/diffusion/sampling_util.py | 8 +-
.../images/diffusion/ldm/modules/attention.py | 177 +-
.../ldm/modules/diffusionmodules/model.py | 475 +-
.../modules/diffusionmodules/openaimodel.py | 169 +-
.../ldm/modules/diffusionmodules/upscaling.py | 49 +-
.../ldm/modules/diffusionmodules/util.py | 42 +-
.../modules/distributions/distributions.py | 37 +-
examples/images/diffusion/ldm/modules/ema.py | 13 +-
.../diffusion/ldm/modules/encoders/modules.py | 104 +-
.../ldm/modules/image_degradation/bsrgan.py | 209 +-
.../modules/image_degradation/bsrgan_light.py | 174 +-
.../modules/image_degradation/utils_image.py | 278 +-
.../images/diffusion/ldm/modules/midas/api.py | 33 +-
.../ldm/modules/midas/midas/base_model.py | 2 +-
.../ldm/modules/midas/midas/blocks.py | 130 +-
.../ldm/modules/midas/midas/dpt_depth.py | 16 +-
.../ldm/modules/midas/midas/midas_net.py | 7 +-
.../modules/midas/midas/midas_net_custom.py | 87 +-
.../ldm/modules/midas/midas/transforms.py | 50 +-
.../diffusion/ldm/modules/midas/midas/vit.py | 69 +-
.../diffusion/ldm/modules/midas/utils.py | 21 +-
examples/images/diffusion/ldm/util.py | 116 +-
examples/images/diffusion/main.py | 338 +-
.../scripts/download_first_stages.sh | 2 +-
examples/images/diffusion/scripts/img2img.py | 76 +-
examples/images/diffusion/scripts/inpaint.py | 54 +-
examples/images/diffusion/scripts/knn2img.py | 153 +-
.../diffusion/scripts/sample_diffusion.py | 158 +-
.../scripts/tests/test_checkpoint.py | 20 +-
.../diffusion/scripts/tests/test_watermark.py | 8 +-
.../diffusion/scripts/train_searcher.py | 166 +-
examples/images/diffusion/scripts/txt2img.py | 154 +-
examples/images/diffusion/scripts/utils.py | 38 +-
examples/images/diffusion/setup.py | 16 +-
examples/images/diffusion/train_colossalai.sh | 1 -
examples/images/diffusion/train_ddp.sh | 6 +-
examples/images/dreambooth/README.md | 4 +-
examples/images/dreambooth/debug.py | 8 +-
examples/images/dreambooth/inference.py | 4 +-
.../images/dreambooth/train_dreambooth.py | 115 +-
.../dreambooth/train_dreambooth_colossalai.py | 149 +-
.../train_dreambooth_colossalai_lora.py | 142 +-
.../dreambooth/train_dreambooth_inpaint.py | 150 +-
examples/images/resnet/eval.py | 11 +-
examples/images/resnet/requirements.txt | 2 +-
examples/images/resnet/train.py | 99 +-
examples/images/vit/args.py | 96 +-
examples/images/vit/data.py | 16 +-
examples/images/vit/requirements.txt | 2 +-
examples/images/vit/vit_benchmark.py | 51 +-
examples/images/vit/vit_train_demo.py | 132 +-
examples/inference/bench_bloom.py | 18 +-
examples/inference/bench_llama.py | 23 +-
examples/language/bert/benchmark.py | 56 +-
examples/language/bert/benchmark_utils.py | 9 +-
examples/language/bert/data.py | 16 +-
examples/language/bert/finetune.py | 126 +-
.../gpt/experiments/auto_offload/model_zoo.py | 28 +-
.../experiments/auto_offload/requirements.txt | 2 +-
.../auto_offload/train_gpt_offload.py | 37 +-
.../auto_parallel/auto_parallel_with_gpt.py | 21 +-
.../experiments/auto_parallel/gpt_modules.py | 24 +-
.../pipeline_parallel/model_zoo.py | 33 +-
.../pipeline_parallel/train_gpt_pp.py | 79 +-
.../language/gpt/gemini/commons/model_zoo.py | 33 +-
examples/language/gpt/gemini/commons/utils.py | 13 +-
.../language/gpt/gemini/train_gpt_demo.py | 44 +-
.../language/gpt/hybridparallelism/data.py | 16 +-
.../gpt/hybridparallelism/finetune.py | 130 +-
.../titans/configs/gpt2_small_zero3_pp1d.py | 8 +-
.../gpt/titans/configs/gpt3_zero3_pp1d.py | 8 +-
.../language/gpt/titans/dataset/webtext.py | 15 +-
examples/language/gpt/titans/model/embed.py | 184 +-
examples/language/gpt/titans/model/gpt1d.py | 252 +-
.../gpt/titans/model/pipeline_gpt1d.py | 321 +-
examples/language/gpt/titans/train_gpt.py | 91 +-
examples/language/llama2/attn.py | 9 +-
examples/language/llama2/benchmark.py | 214 +-
examples/language/llama2/data_utils.py | 75 +-
examples/language/llama2/finetune.py | 256 +-
examples/language/llama2/model_utils.py | 8 +-
.../language/llama2/performance_evaluator.py | 31 +-
examples/language/llama2/pretrain.py | 293 +-
examples/language/opt/args.py | 76 +-
examples/language/opt/data.py | 29 +-
examples/language/opt/opt_benchmark.py | 18 +-
examples/language/opt/opt_train_demo.py | 76 +-
examples/language/opt/run_benchmark.sh | 2 +-
.../palm_pytorch/autoregressive_wrapper.py | 2 -
.../palm/palm_pytorch/palm_pytorch.py | 28 +-
examples/language/palm/train.py | 48 +-
examples/tutorial/README.md | 2 +-
.../auto_parallel/auto_ckpt_batchsize_test.py | 16 +-
.../auto_parallel/auto_ckpt_solver_test.py | 34 +-
.../auto_parallel_with_resnet.py | 9 +-
.../tutorial/auto_parallel/bench_utils.py | 64 +-
examples/tutorial/auto_parallel/setup.py | 12 +-
examples/tutorial/download_cifar10.py | 4 +-
examples/tutorial/hybrid_parallel/config.py | 6 +-
examples/tutorial/hybrid_parallel/train.py | 47 +-
.../tutorial/large_batch_optimizer/train.py | 32 +-
.../tutorial/new_api/cifar_resnet/eval.py | 11 +-
.../tutorial/new_api/cifar_resnet/train.py | 101 +-
examples/tutorial/new_api/cifar_vit/train.py | 124 +-
examples/tutorial/new_api/glue_bert/data.py | 16 +-
.../tutorial/new_api/glue_bert/finetune.py | 80 +-
examples/tutorial/opt/inference/batch.py | 29 +-
.../opt/inference/benchmark/locustfile.py | 9 +-
examples/tutorial/opt/inference/cache.py | 4 +-
.../tutorial/opt/inference/opt_fastapi.py | 101 +-
examples/tutorial/opt/inference/opt_server.py | 119 +-
.../script/process-opt-175b/README.md | 1 -
.../script/process-opt-175b/convert_ckpt.py | 39 +-
.../script/process-opt-175b/flat-meta.json | 6945 ++++++++++++++++-
.../inference/script/processing_ckpt_66b.py | 24 +-
examples/tutorial/opt/opt/colossalai_zero.py | 8 +-
examples/tutorial/opt/opt/context.py | 2 +-
examples/tutorial/opt/opt/run_clm.py | 154 +-
examples/tutorial/sequence_parallel/config.py | 6 +-
.../sequence_parallel/data/__init__.py | 52 +-
.../sequence_parallel/data/bert_helper.py | 45 +-
.../data/datasets/bert_dataset.py | 153 +-
.../data/datasets/blendable_dataset.py | 18 +-
.../data/datasets/builder.py | 134 +-
.../data/datasets/data_samplers.py | 85 +-
.../data/datasets/dataset_utils.py | 274 +-
.../data/datasets/helpers.cpp | 1163 ++-
.../data/datasets/ict_dataset.py | 67 +-
.../data/datasets/indexed_dataset.py | 147 +-
.../datasets/test/test_indexed_dataset.py | 59 +-
.../data/dummy_dataloader.py | 55 +-
.../data/tokenizer/__init__.py | 1 -
.../data/tokenizer/bert_tokenization.py | 67 +-
.../data/tokenizer/tokenizer.py | 69 +-
.../sequence_parallel/loss_func/bert_loss.py | 5 -
.../loss_func/cross_entropy.py | 5 +-
.../sequence_parallel/loss_func/utils.py | 17 +-
.../lr_scheduler/annealing_lr.py | 104 +-
.../sequence_parallel/model/__init__.py | 2 -
.../tutorial/sequence_parallel/model/bert.py | 126 +-
.../model/layers/__init__.py | 2 +-
.../model/layers/bert_layer.py | 34 +-
.../sequence_parallel/model/layers/dropout.py | 4 +-
.../model/layers/embedding.py | 35 +-
.../sequence_parallel/model/layers/head.py | 5 -
.../model/layers/init_method.py | 4 +-
.../sequence_parallel/model/layers/linear.py | 25 +-
.../sequence_parallel/model/layers/mlp.py | 20 +-
.../sequence_parallel/model/layers/pooler.py | 1 +
.../model/layers/preprocess.py | 10 +-
examples/tutorial/sequence_parallel/train.py | 115 +-
op_builder/__init__.py | 29 +-
op_builder/builder.py | 74 +-
op_builder/cpu_adam.py | 24 +-
op_builder/fused_optim.py | 25 +-
op_builder/layernorm.py | 12 +-
op_builder/moe.py | 20 +-
op_builder/multi_head_attn.py | 32 +-
op_builder/scaled_masked_softmax.py | 27 +-
.../scaled_upper_triangle_masked_softmax.py | 24 +-
op_builder/utils.py | 61 +-
setup.py | 133 +-
tests/components_to_test/__init__.py | 16 +-
tests/components_to_test/albert.py | 51 +-
tests/components_to_test/beit.py | 30 +-
tests/components_to_test/bert.py | 61 +-
tests/components_to_test/gpt2.py | 60 +-
.../components_to_test/hanging_param_model.py | 5 +-
tests/components_to_test/inline_op_model.py | 7 +-
tests/components_to_test/nested_model.py | 6 +-
tests/components_to_test/registry.py | 3 +-
.../repeated_computed_layers.py | 4 +-
tests/components_to_test/resnet.py | 17 +-
tests/components_to_test/simple_net.py | 5 +-
.../utils/dummy_data_generator.py | 1 -
tests/kit/model_zoo/__init__.py | 2 +-
tests/kit/model_zoo/diffusers/diffusers.py | 76 +-
tests/kit/model_zoo/registry.py | 21 +-
tests/kit/model_zoo/timm/timm.py | 316 +-
tests/kit/model_zoo/torchaudio/torchaudio.py | 135 +-
tests/kit/model_zoo/torchrec/torchrec.py | 126 +-
.../kit/model_zoo/torchvision/torchvision.py | 202 +-
tests/kit/model_zoo/transformers/albert.py | 98 +-
tests/kit/model_zoo/transformers/bert.py | 453 +-
tests/kit/model_zoo/transformers/blip2.py | 30 +-
tests/kit/model_zoo/transformers/bloom.py | 101 +-
tests/kit/model_zoo/transformers/chatglm2.py | 60 +-
tests/kit/model_zoo/transformers/gpt.py | 124 +-
tests/kit/model_zoo/transformers/llama.py | 63 +-
tests/kit/model_zoo/transformers/opt.py | 59 +-
tests/kit/model_zoo/transformers/sam.py | 24 +-
tests/kit/model_zoo/transformers/t5.py | 46 +-
tests/kit/model_zoo/transformers/vit.py | 48 +-
tests/kit/model_zoo/transformers/whisper.py | 52 +-
.../test_fx/test_bias_addition.py | 58 +-
tests/test_analyzer/test_fx/test_mod_dir.py | 35 +-
.../test_analyzer/test_fx/test_nested_ckpt.py | 7 +-
.../test_analyzer/test_fx/test_shape_prop.py | 20 +-
.../test_fx/test_symbolic_profile.py | 11 +-
.../test_subclasses/test_aten.py | 45 +-
.../test_subclasses/test_flop_tensor.py | 49 +-
.../test_subclasses/test_meta_mode.py | 21 +-
.../test_C_solver_consistency.py | 19 +-
.../test_ckpt_torchvision.py | 37 +-
.../test_ckpt_solvers/test_linearize.py | 26 +-
.../test_offload/model_utils.py | 51 +-
.../test_offload/test_perf.py | 58 +-
.../test_offload/test_solver.py | 15 +-
.../test_pass/test_node_converting_pass.py | 12 +-
.../test_size_value_converting_pass.py | 12 +-
.../test_bias_addition_forward.py | 26 +-
.../test_tensor_shard/test_broadcast.py | 14 +-
.../test_tensor_shard/test_checkpoint.py | 26 +-
.../test_compatibility_with_ddp.py | 42 +-
.../test_compatibility_with_gemini.py | 50 +-
.../test_find_repeat_block.py | 19 +-
.../test_tensor_shard/test_gpt/gpt_modules.py | 22 +-
.../test_gpt/test_runtime_with_gpt_modules.py | 68 +-
.../test_gpt/test_solver_with_gpt_module.py | 25 +-
.../test_liveness_analysis.py | 11 +-
.../test_metainfo/test_activation_metainfo.py | 39 +-
.../test_binary_elementwise_metainfo.py | 24 +-
.../test_metainfo/test_conv_metainfo.py | 51 +-
.../test_metainfo/test_embedding_metainfo.py | 20 +-
.../test_metainfo/test_linear_metainfo.py | 51 +-
.../test_metainfo/test_matmul_metainfo.py | 47 +-
.../test_metainfo/test_norm_metainfo.py | 55 +-
.../test_metainfo/test_pooling_metainfo.py | 42 +-
.../test_metainfo/test_tensor_metainfo.py | 23 +-
.../test_metainfo/test_where_metainfo.py | 22 +-
.../test_tensor_shard/test_metainfo/utils.py | 87 +-
.../test_node_handler/test_addbmm_handler.py | 174 +-
.../test_node_handler/test_addmm_handler.py | 94 +-
.../test_batch_norm_handler.py | 60 +-
.../test_bias_linear_function_node.py | 107 +-
.../test_bias_linear_module_node.py | 105 +-
.../test_binary_elementwise_handler.py | 151 +-
.../test_node_handler/test_bmm_handler.py | 136 +-
.../test_node_handler/test_conv_handler.py | 199 +-
.../test_default_reshape_handler.py | 41 +-
.../test_embedding_handler.py | 198 +-
.../test_node_handler/test_getattr_handler.py | 37 +-
.../test_node_handler/test_getitem_handler.py | 94 +-
.../test_layer_norm_handler.py | 62 +-
.../test_node_handler/test_linear_handler.py | 229 +-
.../test_node_handler/test_matmul_handler.py | 79 +-
.../test_norm_pooling_handler.py | 26 +-
.../test_node_handler/test_output_handler.py | 25 +-
.../test_permute_and_transpose_handler.py | 365 +-
.../test_placeholder_handler.py | 37 +-
.../test_node_handler/test_shard_option.py | 74 +-
.../test_node_handler/test_softmax_handler.py | 169 +-
.../test_node_handler/test_split_handler.py | 263 +-
.../test_node_handler/test_sum_handler.py | 255 +-
.../test_tensor_constructor.py | 17 +-
.../test_unary_element_wise_handler.py | 41 +-
.../test_node_handler/test_view_handler.py | 261 +-
.../test_node_handler/test_where_handler.py | 55 +-
.../test_node_handler/utils.py | 102 +-
.../test_solver_with_resnet_v2.py | 21 +-
.../benchmark_autochunk_alphafold.py | 5 +-
.../test_autochunk_alphafold_utils.py | 7 +-
.../test_autochunk_evoformer_block.py | 53 +-
.../test_autochunk_evoformer_stack.py | 47 +-
.../test_autochunk_extramsa_block.py | 43 +-
.../benchmark_autochunk_diffuser.py | 17 +-
.../test_autochunk_diffuser_utils.py | 10 +-
.../test_autochunk_unet.py | 2 +
.../benchmark_autochunk_transformer.py | 17 +-
.../test_autochunk_gpt.py | 23 +-
.../test_autochunk_transformer_utils.py | 34 +-
.../test_autochunk_vit/test_autochunk_vit.py | 3 +-
.../test_autochunk_vit_utils.py | 8 +-
tests/test_booster/test_accelerator.py | 2 +-
.../test_mixed_precision/test_fp16_torch.py | 8 +-
.../test_plugin/test_3d_plugin.py | 31 +-
.../test_plugin/test_dp_plugin_base.py | 10 +-
.../test_plugin/test_gemini_plugin.py | 82 +-
.../test_plugin/test_low_level_zero_plugin.py | 22 +-
.../test_plugin/test_torch_ddp_plugin.py | 14 +-
.../test_plugin/test_torch_fsdp_plugin.py | 22 +-
.../test_gemini_checkpoint_io.py | 77 +-
.../test_gemini_torch_compability.py | 47 +-
.../test_general_checkpoint_io.py | 33 +-
...st_hybrid_parallel_plugin_checkpoint_io.py | 92 +-
.../test_low_level_zero_checkpoint_io.py | 15 +-
.../test_plugins_huggingface_compatibility.py | 31 +-
.../test_torch_ddp_checkpoint_io.py | 14 +-
.../test_torch_fsdp_checkpoint_io.py | 19 +-
tests/test_checkpoint_io/utils.py | 2 +-
.../test_cluster/test_device_mesh_manager.py | 8 +-
tests/test_cluster/test_process_group_mesh.py | 56 +-
tests/test_config/sample_config.py | 20 +-
tests/test_config/test_load_config.py | 13 +-
tests/test_device/test_alpha_beta.py | 6 +-
tests/test_device/test_device_mesh.py | 22 +-
tests/test_device/test_extract_alpha_beta.py | 6 +-
tests/test_device/test_init_logical_pg.py | 4 +-
.../test_search_logical_device_mesh.py | 6 +-
.../test_activation_checkpoint_codegen.py | 61 +-
...st_nested_activation_checkpoint_codegen.py | 75 +-
.../test_codegen/test_offload_codegen.py | 76 +-
tests/test_fx/test_coloproxy.py | 9 +-
tests/test_fx/test_comm_size_compute.py | 5 +-
tests/test_fx/test_graph_manipulation.py | 10 +-
tests/test_fx/test_meta/test_aten.py | 45 +-
tests/test_fx/test_meta/test_backward.py | 31 +-
tests/test_fx/test_meta/test_meta_trace.py | 31 +-
tests/test_fx/test_meta_info_prop.py | 14 +-
tests/test_fx/test_parallel_1d.py | 7 +-
.../test_pipeline/test_hf_model/hf_utils.py | 25 +-
.../test_hf_model/test_albert.py | 18 +-
.../test_pipeline/test_hf_model/test_bert.py | 12 +-
.../test_pipeline/test_hf_model/test_gpt.py | 6 +-
.../test_pipeline/test_hf_model/test_opt.py | 4 +-
.../test_pipeline/test_hf_model/test_t5.py | 4 +-
.../test_timm_model/test_timm.py | 17 +-
.../test_timm_model/timm_utils.py | 13 +-
.../test_pipeline/test_topo/test_topo.py | 9 +-
.../test_pipeline/test_topo/topo_utils.py | 15 +-
.../test_torchvision/test_torchvision.py | 19 +-
tests/test_fx/test_pipeline_passes.py | 7 +-
tests/test_fx/test_profiler/gpt_utils.py | 34 +-
.../test_profiler_meta_info_prop.py | 81 +-
.../test_activation_checkpoint_annotation.py | 15 +-
.../test_tracer/test_bias_addition_module.py | 23 +-
.../test_fx/test_tracer/test_control_flow.py | 19 +-
.../test_tracer/test_functional_conv.py | 2 +-
.../test_hf_model/hf_tracer_utils.py | 7 +-
.../test_hf_model/test_hf_albert.py | 6 +-
.../test_tracer/test_hf_model/test_hf_bert.py | 8 +-
.../test_hf_model/test_hf_diffuser.py | 10 +-
.../test_tracer/test_hf_model/test_hf_gpt.py | 10 +-
.../test_tracer/test_hf_model/test_hf_opt.py | 8 +-
.../test_tracer/test_hf_model/test_hf_t5.py | 10 +-
.../test_tracer/test_patched_module.py | 515 +-
tests/test_fx/test_tracer/test_patched_op.py | 36 +-
.../test_timm_model/test_timm_model.py | 15 +-
.../test_torchaudio_model.py | 10 +-
.../test_torchaudio_model/torchaudio_utils.py | 7 +-
.../test_torchrec_model/test_deepfm_model.py | 22 +-
.../test_torchrec_model/test_dlrm_model.py | 24 +-
.../test_torchvision_model.py | 8 +-
tests/test_infer/_utils.py | 26 +-
tests/test_infer/test_bloom_infer.py | 27 +-
tests/test_infer/test_infer_engine.py | 36 +-
tests/test_infer/test_kvcache_manager.py | 39 +-
tests/test_infer/test_llama_infer.py | 31 +-
.../test_infer_ops/cuda/test_vllm_rmsnorm.py | 14 +-
.../cuda/test_vllm_rotary_embedding.py | 47 +-
tests/test_infer_ops/triton/kernel_utils.py | 9 +-
.../triton/test_bloom_context_attention.py | 20 +-
.../triton/test_copy_kv_dest.py | 18 +-
.../triton/test_layernorm_triton.py | 21 +-
.../triton/test_llama_context_attention.py | 20 +-
.../triton/test_rotary_embedding.py | 21 +-
.../triton/test_self_attention_nonfusion.py | 104 +-
tests/test_infer_ops/triton/test_softmax.py | 25 +-
.../triton/test_token_attn_1.py | 18 +-
.../triton/test_token_attn_2.py | 26 +-
.../triton/test_token_attn_fwd.py | 14 +-
.../triton/test_token_softmax.py | 12 +-
tests/test_lazy/lazy_init_utils.py | 38 +-
tests/test_lazy/test_models.py | 18 +-
tests/test_legacy/test_amp/test_naive_fp16.py | 15 +-
tests/test_legacy/test_amp/test_torch_fp16.py | 16 +-
.../test_comm/test_boardcast_send_recv_v2.py | 16 +-
tests/test_legacy/test_comm/test_comm.py | 24 +-
.../test_comm/test_object_list_p2p.py | 16 +-
.../test_comm/test_object_list_p2p_v2.py | 12 +-
.../test_context/configs/parallel_2d_init.py | 2 +-
.../configs/parallel_2p5d_init.py | 2 +-
.../test_context/configs/parallel_3d_init.py | 2 +-
.../test_context/test_hybrid_parallel.py | 32 +-
.../test_data/test_cifar10_dataset.py | 5 +-
.../test_data/test_data_parallel_sampler.py | 28 +-
.../test_deterministic_dataloader.py | 19 +-
tests/test_legacy/test_engine/test_engine.py | 34 +-
.../test_engine/test_gradient_accumluation.py | 57 +-
.../test_1d/checks_1d/check_layer_1d.py | 44 +-
.../test_layers/test_1d/test_1d.py | 8 +-
.../test_2d/checks_2d/check_layer_2d.py | 48 +-
.../test_2d/checks_2d/check_operation_2d.py | 107 +-
.../test_layers/test_2d/test_2d.py | 8 +-
.../test_2p5d/checks_2p5d/check_layer_2p5d.py | 70 +-
.../checks_2p5d/check_operation_2p5d.py | 113 +-
.../test_layers/test_2p5d/test_2p5d.py | 14 +-
.../test_3d/checks_3d/check_layer_3d.py | 241 +-
.../test_layers/test_3d/test_3d.py | 6 +-
.../test_layers/test_cache_embedding.py | 162 +-
.../checks_seq/check_layer_seq.py | 7 +-
.../test_sequence/test_sequence.py | 32 +-
.../test_pipeline/rpc_test_utils.py | 71 +-
.../test_pipeline/test_cuda_rpc_chimera.py | 18 +-
.../test_pipeline/test_cuda_rpc_optimizer.py | 21 +-
.../test_pipeline/test_cuda_rpc_pipeline.py | 17 +-
.../test_cuda_rpc_value_correctness.py | 19 +-
.../test_pipeline/test_middleware_1f1b.py | 33 +-
.../test_pipeline/test_pipelinable.py | 5 +-
.../test_pipeline_process_group.py | 20 +-
.../test_tensor/common_utils/_utils.py | 19 +-
.../test_tensor/core/test_dist_spec_mgr.py | 6 +-
.../test_legacy/test_tensor/test_parameter.py | 9 +-
.../test_trainer/test_pipeline/test_p2p.py | 30 +-
.../test_pipeline/test_pipeline_schedule.py | 25 +-
.../test_trainer_with_non_pipe_schedule.py | 34 +-
.../test_trainer_with_pipe_schedule.py | 55 +-
.../test_activation_checkpointing.py | 14 +-
.../test_checkpoint/test_checkpoint_1d.py | 4 +-
.../test_checkpoint/test_checkpoint_2d.py | 4 +-
.../test_checkpoint/test_checkpoint_2p5d.py | 4 +-
.../test_checkpoint/test_checkpoint_3d.py | 4 +-
tests/test_legacy/test_utils/test_memory.py | 4 +-
.../test_utils/test_norm_gradient_clipping.py | 26 +-
tests/test_legacy/test_zero/test_commons.py | 18 +-
tests/test_moe/test_grad_handler.py | 6 +-
tests/test_moe/test_kernel.py | 14 +-
tests/test_moe/test_moe_checkpoint.py | 10 +-
tests/test_moe/test_moe_colo_init.py | 11 +-
tests/test_moe/test_moe_group.py | 6 +-
tests/test_moe/test_moe_zero_init.py | 46 +-
tests/test_moe/test_moe_zero_model.py | 12 +-
tests/test_moe/test_moe_zero_optim.py | 43 +-
tests/test_optimizer/test_adam_kernel.py | 92 +-
tests/test_optimizer/test_adam_optim.py | 33 +-
tests/test_optimizer/test_nvme.py | 22 +-
tests/test_pipeline/test_p2p_communication.py | 12 +-
.../test_t5_pipeline_utils.py | 42 +-
.../test_whisper_pipeline_utils.py | 46 +-
.../test_schedule/test_interleaved.py | 44 +-
.../test_schedule/test_oneF_oneB.py | 28 +-
.../test_pipeline_schedule_utils.py | 24 +-
tests/test_pipeline/test_stage_manager.py | 4 +-
.../test_layer/test_dist_crossentropy.py | 15 +-
.../test_layer/test_dropout.py | 4 +-
.../test_layer/test_embedding.py | 6 +-
.../test_gpt2_qkv_fused_linear_1d.py | 26 +-
.../test_layer/test_layernorm.py | 6 +-
.../test_layer/test_linear_1d.py | 54 +-
.../test_layer/test_qkv_fused_linear_1d.py | 15 +-
.../test_vocab_parallel_embedding_1d.py | 12 +-
tests/test_shardformer/test_model/_utils.py | 212 +-
.../test_model/test_shard_bert.py | 202 +-
.../test_model/test_shard_blip2.py | 39 +-
.../test_model/test_shard_bloom.py | 198 +-
.../test_model/test_shard_chatglm2.py | 214 +-
.../test_model/test_shard_gpt2.py | 228 +-
.../test_model/test_shard_llama.py | 231 +-
.../test_model/test_shard_opt.py | 213 +-
.../test_model/test_shard_sam.py | 31 +-
.../test_model/test_shard_t5.py | 205 +-
.../test_model/test_shard_vit.py | 209 +-
.../test_model/test_shard_whisper.py | 202 +-
tests/test_shardformer/test_shard_utils.py | 1 -
tests/test_shardformer/test_with_torch_ddp.py | 7 +-
tests/test_tensor/test_comm_spec_apply.py | 19 +-
.../test_dtensor/test_comm_spec.py | 30 +-
.../test_tensor/test_dtensor/test_dtensor.py | 17 +-
.../test_dtensor_sharding_spec.py | 7 +-
.../test_dtensor/test_layout_converter.py | 36 +-
tests/test_tensor/test_mix_gather.py | 150 +-
tests/test_tensor/test_shape_consistency.py | 63 +-
.../test_shape_consistency_apply.py | 6 +-
tests/test_tensor/test_sharding_spec.py | 4 +-
tests/test_utils/test_flash_attention.py | 34 +-
.../test_zero/test_gemini/test_chunk_mgrv2.py | 35 +-
tests/test_zero/test_gemini/test_chunkv2.py | 42 +-
tests/test_zero/test_gemini/test_fwd_bwd.py | 41 +-
.../test_gemini/test_gemini_use_rmt.py | 45 +-
tests/test_zero/test_gemini/test_grad_clip.py | 63 +-
tests/test_zero/test_gemini/test_inference.py | 37 +-
tests/test_zero/test_gemini/test_optim.py | 110 +-
.../test_gemini/test_runtime_mem_tracer.py | 8 +-
tests/test_zero/test_gemini/test_search.py | 36 +-
.../test_gemini/test_zeroddp_state_dict.py | 51 +-
.../test_gemini/test_zerooptim_state_dict.py | 44 +-
.../test_zero/test_low_level/test_grad_acc.py | 29 +-
.../test_zero/test_low_level/test_zero1_2.py | 29 +-
.../test_low_level/test_zero_ckpt.py | 19 +-
1268 files changed, 50252 insertions(+), 38659 deletions(-)
delete mode 100644 .flake8
delete mode 100644 .style.yapf
diff --git a/.flake8 b/.flake8
deleted file mode 100644
index 229856aa4366..000000000000
--- a/.flake8
+++ /dev/null
@@ -1,22 +0,0 @@
-[flake8]
-ignore =
- ;W503 line break before binary operator
- W503,
- ;E203 whitespace before ':'
- E203,
-
-; exclude file
-exclude =
- .tox,
- .git,
- __pycache__,
- build,
- dist,
- *.pyc,
- *.egg-info,
- .cache,
- .eggs
-
-max-line-length = 120
-
-per-file-ignores = __init__.py:F401
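
Note on the deletion above: E203 ("whitespace before ':'") and W503 ("line
break before binary operator") are exactly the two pycodestyle checks that
conflict with the output of the black formatter, so dropping them along with
the whole .flake8 (and .style.yapf, per the stat) is consistent with a
migration to black; the single-to-double quote rewrites in the hunks below
point the same way. A minimal sketch of the two conflicts, assuming black's
documented slice and operator-wrapping style:

    # E203: black puts spaces around ':' in slices whose bounds are compound
    # expressions, which pycodestyle would flag.
    prefix = "examples/"
    path = "examples/images/vit"
    suffix = path[len(prefix) :]  # black-style output; flake8 reports E203 here

    # W503: black wraps long expressions by breaking *before* the binary
    # operator, which pycodestyle flags as W503.
    total = (
        len(prefix)
        + len(suffix)
    )
    print(suffix, total)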
diff --git a/.github/workflows/scripts/check_doc_i18n.py b/.github/workflows/scripts/check_doc_i18n.py
index 1aa7283e9e52..1e7f0c33a785 100644
--- a/.github/workflows/scripts/check_doc_i18n.py
+++ b/.github/workflows/scripts/check_doc_i18n.py
@@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2):
# If the corresponding item doesn't exist in the second directory, the directories are different
if not os.path.exists(item_path2):
- print(f'Found mismatch: {item_path1}, {item_path2}')
+ print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# If the corresponding item is a directory, we compare the two directories recursively
if os.path.isdir(item_path1) and os.path.isdir(item_path2):
if not compare_dirs(item_path1, item_path2):
- print(f'Found mismatch: {item_path1}, {item_path2}')
+ print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# both are files
@@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2):
# If the corresponding item is not a file or a directory, the directories are different
else:
- print(f'Found mismatch: {item_path1}, {item_path2}')
+ print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# If all items are the same, the directories are the same
return True
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.")
+ parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.")
args = parser.parse_args()
i18n_folders = os.listdir(args.directory)
@@ -56,7 +56,7 @@ def compare_dirs(dir1, dir2):
for i in range(1, len(i18n_folders)):
dir1 = i18n_folders[0]
dir2 = i18n_folders[i]
- print(f'comparing {dir1} vs {dir2}')
+ print(f"comparing {dir1} vs {dir2}")
match = compare_dirs(i18n_folders[0], i18n_folders[i])
if not match:
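
`compare_dirs` recursively checks that every language folder mirrors the same
file tree. A minimal sketch of a local run, assuming the translated docs live
in parallel sub-folders such as docs/source/en and docs/source/zh-Hans (the
layout is an assumption for illustration; only the -d flag comes from the
script above):

    import subprocess

    # Hypothetical invocation; adjust the directory to your checkout.
    result = subprocess.run(
        ["python", ".github/workflows/scripts/check_doc_i18n.py", "-d", "docs/source"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)  # any "Found mismatch: ..." lines appear here
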
diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
index 5bec96187e0c..91778f692cc6 100644
--- a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
+++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
@@ -4,7 +4,7 @@
def check_inputs(input_list):
for path in input_list:
- real_path = os.path.join('examples', path)
+ real_path = os.path.join("examples", path)
if not os.path.exists(real_path):
return False
return True
@@ -12,16 +12,16 @@ def check_inputs(input_list):
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
+ parser.add_argument("-f", "--fileNameList", type=str, help="List of file names")
args = parser.parse_args()
name_list = args.fileNameList.split(",")
is_correct = check_inputs(name_list)
if is_correct:
- print('success')
+ print("success")
else:
- print('failure')
+ print("failure")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py
index 83eff644e315..95a3d24c9a78 100644
--- a/.github/workflows/scripts/example_checks/check_example_weekly.py
+++ b/.github/workflows/scripts/example_checks/check_example_weekly.py
@@ -17,21 +17,21 @@ def show_files(path, all_files):
def join(input_list, sep=None):
- return (sep or ' ').join(input_list)
+ return (sep or " ").join(input_list)
def main():
- contents = show_files('examples/', [])
+ contents = show_files("examples/", [])
all_loc = []
for file_loc in contents:
- split_loc = file_loc.split('/')
+ split_loc = file_loc.split("/")
# Paths must have at least two sub-folder levels under the examples folder: examples/images/vit is acceptable, but examples/images/README.md and examples/requirements.txt are not.
if len(split_loc) >= 4:
- re_loc = '/'.join(split_loc[1:3])
+ re_loc = "/".join(split_loc[1:3])
if re_loc not in all_loc:
all_loc.append(re_loc)
print(all_loc)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
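
The depth filter above is easier to see with concrete paths; a quick
illustration of the split/join logic (the file names are made up):

    # "examples/images/vit/train.py" has 4 path segments, so it passes the
    # depth check and contributes the sub-folder "images/vit" exactly once.
    parts = "examples/images/vit/train.py".split("/")
    assert len(parts) >= 4
    assert "/".join(parts[1:3]) == "images/vit"

    # "examples/README.md" has only 2 segments and is skipped.
    assert len("examples/README.md".split("/")) < 4
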
diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py
index c69d95a552e9..95f671dfb32b 100644
--- a/.github/workflows/scripts/example_checks/detect_changed_example.py
+++ b/.github/workflows/scripts/example_checks/detect_changed_example.py
@@ -3,7 +3,7 @@
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
+ parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files")
args = parser.parse_args()
name_list = args.fileNameList.split(":")
folder_need_check = set()
@@ -15,10 +15,10 @@ def main():
# - application
# - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
- folder_need_check.add('/'.join(loc.split("/")[1:3]))
+ folder_need_check.add("/".join(loc.split("/")[1:3]))
# Output the result using print. Then the shell can get the values.
print(list(folder_need_check))
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
index 2884e38dd3dd..412b14c7b283 100644
--- a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
+++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
@@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]:
# prepare header
headers = {
- 'Authorization': f'Bearer {github_token}',
- 'Accept': 'application/vnd.github+json',
- 'X-GitHub-Api-Version': '2022-11-28'
+ "Authorization": f"Bearer {github_token}",
+ "Accept": "application/vnd.github+json",
+ "X-GitHub-Api-Version": "2022-11-28",
}
res = requests.get(url, headers=headers).json()
repo_list = []
for item in res:
- repo_list.append(item['name'])
+ repo_list.append(item["name"])
return repo_list
@@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
"""
# prepare header
headers = {
- 'Authorization': f'Bearer {github_token}',
- 'Accept': 'application/vnd.github+json',
- 'X-GitHub-Api-Version': '2022-11-28'
+ "Authorization": f"Bearer {github_token}",
+ "Accept": "application/vnd.github+json",
+ "X-GitHub-Api-Version": "2022-11-28",
}
user_engagement_count = {}
@@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
# do pagination to the API
page = 1
while True:
- comment_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}'
+ comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}"
comment_response = requests.get(comment_api, headers=headers).json()
if len(comment_response) == 0:
break
else:
for item in comment_response:
- comment_author_relationship = item['author_association']
- if comment_author_relationship != 'MEMBER':
+ comment_author_relationship = item["author_association"]
+ if comment_author_relationship != "MEMBER":
# if the comment is not made by our member
# we don't count this comment towards user engagement
continue
- issue_id = item['issue_url'].split('/')[-1]
- issue_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}'
+ issue_id = item["issue_url"].split("/")[-1]
+ issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}"
issue_response = requests.get(issue_api, headers=headers).json()
- issue_author_relationship = issue_response['author_association']
+ issue_author_relationship = issue_response["author_association"]
- if issue_author_relationship != 'MEMBER':
+ if issue_author_relationship != "MEMBER":
# this means that the issue/PR is not created by our own people
# any comments in this issue/PR by our member will be counted towards the leaderboard
- member_name = item['user']['login']
+ member_name = item["user"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
@@ -153,7 +153,7 @@ def _generate_discussion_query(num, cursor: str = None):
if cursor is None:
offset_str = ""
else:
- offset_str = f", after: \"{cursor}\""
+ offset_str = f', after: "{cursor}"'
query = f"""
{{
repository(owner: "{org_name}", name: "{repo_name}"){{
@@ -182,7 +182,7 @@ def _generate_comment_reply_count_for_discussion(discussion_number, num, cursor:
if cursor is None:
offset_str = ""
else:
- offset_str = f", before: \"{cursor}\""
+ offset_str = f', before: "{cursor}"'
query = f"""
{{
repository(owner: "{org_name}", name: "{repo_name}"){{
@@ -220,8 +220,8 @@ def _generate_comment_reply_count_for_discussion(discussion_number, num, cursor:
# a utility function to make call to Github GraphQL API
def _call_graphql_api(query):
headers = {"Authorization": f"Bearer {github_token}"}
- json_data = {'query': query}
- response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers)
+ json_data = {"query": query}
+ response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers)
data = response.json()
return data
@@ -234,21 +234,21 @@ def _call_graphql_api(query):
data = _call_graphql_api(query)
found_discussion_out_of_time_range = False
- edges = data['data']['repository']['discussions']['edges']
+ edges = data["data"]["repository"]["discussions"]["edges"]
if len(edges) == 0:
break
else:
# keep the discussion whose author is not a member
for edge in edges:
# print the discussion title
- discussion = edge['node']
- discussion_updated_at = str2datetime(discussion['updatedAt'])
+ discussion = edge["node"]
+ discussion_updated_at = str2datetime(discussion["updatedAt"])
# check if the updatedAt is within the last 7 days
# if yes, add it to discussion_numbers
if discussion_updated_at > since:
- if discussion['authorAssociation'] != 'MEMBER':
- discussion_numbers.append(discussion['number'])
+ if discussion["authorAssociation"] != "MEMBER":
+ discussion_numbers.append(discussion["number"])
else:
found_discussion_out_of_time_range = True
@@ -256,7 +256,7 @@ def _call_graphql_api(query):
break
else:
# update cursor
- cursor = edges[-1]['cursor']
+ cursor = edges[-1]["cursor"]
# get the discussion comments and replies made by our member
user_engagement_count = {}
@@ -269,42 +269,42 @@ def _call_graphql_api(query):
data = _call_graphql_api(query)
# get the comments
- edges = data['data']['repository']['discussion']['comments']['edges']
+ edges = data["data"]["repository"]["discussion"]["comments"]["edges"]
# update the cursor
if len(edges) == 0:
break
else:
# update cursor for pagination
- cursor = edges[-1]['cursor']
+ cursor = edges[-1]["cursor"]
for edge in edges:
- comment = edge['node']
- if comment['authorAssociation'] == 'MEMBER':
+ comment = edge["node"]
+ if comment["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days
# if yes, add it to user_engagement_count
- comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
+ comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if comment_updated_at > since:
- member_name = comment['author']['login']
+ member_name = comment["author"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
else:
user_engagement_count[member_name] = 1
# get the replies
- reply_edges = comment['replies']['edges']
+ reply_edges = comment["replies"]["edges"]
if len(reply_edges) == 0:
continue
else:
for reply_edge in reply_edges:
- reply = reply_edge['node']
- if reply['authorAssociation'] == 'MEMBER':
+ reply = reply_edge["node"]
+ if reply["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days
# if yes, add it to discussion_numbers
- reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
+ reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if reply_updated_at > since:
- member_name = reply['author']['login']
+ member_name = reply["author"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
else:
@@ -312,7 +312,9 @@ def _call_graphql_api(query):
return user_engagement_count
-def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool:
+def generate_user_engagement_leaderboard_image(
+ github_token: str, org_name: str, repo_list: List[str], output_path: str
+) -> bool:
"""
Generate the user engagement leaderboard image for stats within the last 7 days
@@ -335,16 +337,19 @@ def _update_count(counter):
else:
total_engagement_count[name] = count
-
for repo_name in repo_list:
print(f"Fetching user engagement count for {repo_name}/{repo_name}")
- issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str)
- discussion_engagement_count = get_discussion_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime)
+ issue_pr_engagement_count = get_issue_pull_request_comments(
+ github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str
+ )
+ discussion_engagement_count = get_discussion_comments(
+ github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime
+ )
# update the total engagement count
_update_count(issue_pr_engagement_count)
_update_count(discussion_engagement_count)
-
+
# prepare the data for plotting
x = []
y = []
@@ -363,7 +368,7 @@ def _update_count(counter):
# plot the leaderboard
xlabel = f"Number of Comments made (since {start_datetime_str})"
ylabel = "Member"
- title = 'Active User Engagement Leaderboard'
+ title = "Active User Engagement Leaderboard"
plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True
else:
@@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
"""
# request to the Github API to get the users who have contributed in the last 7 days
headers = {
- 'Authorization': f'Bearer {github_token}',
- 'Accept': 'application/vnd.github+json',
- 'X-GitHub-Api-Version': '2022-11-28'
+ "Authorization": f"Bearer {github_token}",
+ "Accept": "application/vnd.github+json",
+ "X-GitHub-Api-Version": "2022-11-28",
}
counter = Counter()
start_datetime = get_utc_time_one_week_ago()
def _get_url(org_name, repo_name, page):
- return f'https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed'
+ return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed"
def _iterate_by_page(org_name, repo_name):
page = 1
@@ -415,8 +420,8 @@ def _iterate_by_page(org_name, repo_name):
# count the pull request and author from response
for pr_data in response:
- merged_at = pr_data['merged_at']
- author = pr_data['user']['login']
+ merged_at = pr_data["merged_at"]
+ author = pr_data["user"]["login"]
if merged_at is None:
continue
@@ -439,7 +444,7 @@ def _iterate_by_page(org_name, repo_name):
_iterate_by_page(org_name, repo_name)
# convert unix timestamp to Beijing datetime
- bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone('Asia/Shanghai'))
+ bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai"))
bj_start_datetime_str = datetime2str(bj_start_datetime)
contribution_list = counter.to_sorted_list()
@@ -452,7 +457,7 @@ def _iterate_by_page(org_name, repo_name):
if len(author_list) > 0:
xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})"
ylabel = "Contributor"
- title = 'Active Contributor Leaderboard'
+ title = "Active Contributor Leaderboard"
plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True
else:
@@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str:
image_path (str): the path to the image to be uploaded
"""
url = "https://open.feishu.cn/open-apis/im/v1/images"
- form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # replace with the actual image path
+ form = {"image_type": "message", "image": (open(image_path, "rb"))} # replace with the actual image path
multi_form = MultipartEncoder(form)
headers = {
- 'Authorization': f'Bearer {lark_tenant_token}', ## obtain a tenant_access_token and replace with the actual token
+ "Authorization": f"Bearer {lark_tenant_token}", ## obtain a tenant_access_token and replace with the actual token
}
- headers['Content-Type'] = multi_form.content_type
+ headers["Content-Type"] = multi_form.content_type
response = requests.request("POST", url, headers=headers, data=multi_form).json()
- return response['data']['image_key']
+ return response["data"]["image_key"]
def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
@@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
app_id (str): Lark app id
app_secret (str): Lark app secret
"""
- url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
- data = {'app_id': app_id, 'app_secret': app_secret}
+ url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
+ data = {"app_id": app_id, "app_secret": app_secret}
response = requests.post(url, json=data).json()
- return response['tenant_access_token']
+ return response["tenant_access_token"]
def send_image_to_lark(image_key: str, webhook_url: str) -> None:
@@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str):
requests.post(webhook_url, json=data)
-if __name__ == '__main__':
- GITHUB_TOKEN = os.environ['GITHUB_TOKEN']
- CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png'
- USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png'
+if __name__ == "__main__":
+ GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
+ CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png"
+ USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png"
ORG_NAME = "hpcaitech"
# get all open source repositories
@@ -527,17 +532,19 @@ def send_message_to_lark(message: str, webhook_url: str):
# generate images
contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH)
- engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH)
+ engagement_success = generate_user_engagement_leaderboard_image(
+ GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH
+ )
# upload images
- APP_ID = os.environ['LARK_APP_ID']
- APP_SECRET = os.environ['LARK_APP_SECRET']
+ APP_ID = os.environ["LARK_APP_ID"]
+ APP_SECRET = os.environ["LARK_APP_SECRET"]
LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET)
contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH)
user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH)
# send message to lark
- LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL']
+ LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"]
message = """本周的社区榜单出炉啦!
1. 开发贡献者榜单
2. 用户互动榜单
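
The leaderboard script pages through several GitHub REST endpoints with the
same loop shape; reduced to a hedged sketch (the endpoint URL and token below
are placeholders, not values from this repository):

    import requests

    def fetch_all_pages(api_url: str, token: str) -> list:
        """Collect every page of a GitHub list endpoint, stopping on an empty page."""
        headers = {
            "Authorization": f"Bearer {token}",
            "Accept": "application/vnd.github+json",
        }
        items, page = [], 1
        while True:
            resp = requests.get(f"{api_url}?per_page=50&page={page}", headers=headers).json()
            if not resp:
                break
            items.extend(resp)
            page += 1
        return items
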
diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py
index dc592e4c977b..7374481005ef 100644
--- a/.github/workflows/scripts/generate_release_draft.py
+++ b/.github/workflows/scripts/generate_release_draft.py
@@ -7,27 +7,27 @@
import requests
-COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits'
-TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags'
+COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits"
+TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags"
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--out', type=str, help='output path for the release draft', required=True)
- parser.add_argument('--version', type=str, help='current version to release', required=True)
+ parser.add_argument("--out", type=str, help="output path for the release draft", required=True)
+ parser.add_argument("--version", type=str, help="current version to release", required=True)
return parser.parse_args()
def get_latest_tag_commit(headers=None):
res = requests.get(url=TAGS_API, headers=headers)
data = res.json()
- commit_hash = data[0]['commit']['sha']
- version = data[0]['name']
+ commit_hash = data[0]["commit"]["sha"]
+ version = data[0]["name"]
return commit_hash, version
def get_commit_info(commit_hash, headers=None):
- api = f'{COMMIT_API}/{commit_hash}'
+ api = f"{COMMIT_API}/{commit_hash}"
res = requests.get(url=api, headers=headers)
return res.json()
@@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None):
results = []
while True:
- api = f'{COMMIT_API}?since={since}&per_page=100&page={page}'
+ api = f"{COMMIT_API}?since={since}&per_page=100&page={page}"
resp = requests.get(url=api, headers=headers)
data = resp.json()
@@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None):
def collate_release_info(commit_info_list):
results = dict()
- pattern = pattern = r'\[.*\]'
+ pattern = pattern = r"\[.*\]"
for commit_info in commit_info_list:
- author = commit_info['commit']['author']['name']
+ author = commit_info["commit"]["author"]["name"]
try:
- author_url = commit_info['author']['url']
+ author_url = commit_info["author"]["url"]
except:
# author can be None
author_url = None
- msg = commit_info['commit']['message']
+ msg = commit_info["commit"]["message"]
match = re.search(pattern, msg)
if match:
- tag = match.group().lstrip('[').rstrip(']').capitalize()
+ tag = match.group().lstrip("[").rstrip("]").capitalize()
if tag not in results:
results[tag] = []
results[tag].append((msg, author, author_url))
@@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info):
for msg, author, author_url in v:
# only keep the first line
- msg = msg.split('\n')[0]
+ msg = msg.split("\n")[0]
if author_url:
- item = f'{msg} by [{author}]({author_url})\n'
+ item = f"{msg} by [{author}]({author_url})\n"
else:
- item = f'{msg} by {author}\n'
- text.append(f'- {item}')
+ item = f"{msg} by {author}\n"
+ text.append(f"- {item}")
- text.append('\n')
+ text.append("\n")
# add full change log
text.append(
- f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}')
+ f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}"
+ )
return text
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
- token = os.environ['GITHUB_API_TOKEN']
- headers = {'Authorization': token}
+ token = os.environ["GITHUB_API_TOKEN"]
+ headers = {"Authorization": token}
# get previous release tag
last_release_commit, last_version = get_latest_tag_commit(headers)
last_release_commit_info = get_commit_info(last_release_commit, headers=headers)
- last_release_date = last_release_commit_info['commit']['author']['date']
+ last_release_date = last_release_commit_info["commit"]["author"]["date"]
# get the commits since last release
commit_info = get_all_commit_info(since=last_release_date, headers=headers)
- commit_info = commit_info[:-1] # remove the release commit
+ commit_info = commit_info[:-1] # remove the release commit
# collate into markdown
release_info = collate_release_info(commit_info)
markdown_text = generate_release_post_markdown(args.version, last_version, release_info)
# write into a file
- with open(args.out, 'w') as f:
+ with open(args.out, "w") as f:
for line in markdown_text:
f.write(line)
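
The `\[.*\]` pattern collates commits by the bracketed tag at the start of
the message; for example (the commit message is invented):

    import re

    msg = "[shardformer] support sharded llama (#1234)"
    match = re.search(r"\[.*\]", msg)
    tag = match.group().lstrip("[").rstrip("]").capitalize()
    assert tag == "Shardformer"
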
diff --git a/.github/workflows/scripts/send_message_to_lark.py b/.github/workflows/scripts/send_message_to_lark.py
index a113327a786e..bc005d93c3f5 100644
--- a/.github/workflows/scripts/send_message_to_lark.py
+++ b/.github/workflows/scripts/send_message_to_lark.py
@@ -5,8 +5,8 @@
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('-m', '--message', type=str)
- parser.add_argument('-u', '--url', type=str)
+ parser.add_argument("-m", "--message", type=str)
+ parser.add_argument("-u", "--url", type=str)
return parser.parse_args()
@@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url):
requests.post(webhook_url, json=data)
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
send_message_to_lark(args.message, args.url)
diff --git a/.isort.cfg b/.isort.cfg
index 090aa28e39f3..4f881c8b3dda 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -3,3 +3,4 @@ line_length = 120
multi_line_output=3
include_trailing_comma = true
ignore_comments = true
+profile = black
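
Setting `profile = black` keeps isort's multi-line import style compatible
with black, so the two pre-commit hooks stop rewriting each other's output.
A small illustration of the wrapped-import style both tools now agree on
(module and symbol names are made up):

    from some_package.some_module import (
        first_symbol,
        second_symbol,
        third_symbol,
    )
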
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 725d266375ef..9871e1184462 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,23 +1,31 @@
repos:
+ - repo: https://github.com/PyCQA/autoflake
+ rev: v2.2.1
+ hooks:
+ - id: autoflake
+ name: autoflake (python)
+ args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
+
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: sort all imports (python)
- - repo: https://github.com/pre-commit/mirrors-yapf
- rev: v0.32.0
+ - repo: https://github.com/psf/black-pre-commit-mirror
+ rev: 23.9.1
hooks:
- - id: yapf
- name: yapf formatter
- args: ['--style=.style.yapf', '--parallel', '--in-place']
+ - id: black
+ name: black formatter
+ args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.1
hooks:
- id: clang-format
name: clang formatter
+ types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
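
The yapf-to-black switch above drives most of the 1268-file churn in this
patch: black normalizes string quotes to double quotes and explodes any call
that exceeds the 120-character line length onto one argument per line with a
trailing comma. A before/after sketch of the rewrite pattern (the names are
invented, and the call is assumed too long to fit on one line):

    # yapf style (before):
    trainer = Trainer(model,
                      optimizer,
                      train_batch_size=8,
                      callbacks=[evaluator])

    # black style (after):
    trainer = Trainer(
        model,
        optimizer,
        train_batch_size=8,
        callbacks=[evaluator],
    )
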
diff --git a/.style.yapf b/.style.yapf
deleted file mode 100644
index 05be0dc6a3a5..000000000000
--- a/.style.yapf
+++ /dev/null
@@ -1,5 +0,0 @@
-[style]
-based_on_style = google
-spaces_before_comment = 4
-split_before_logical_operator = true
-column_limit = 120
diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
index 90471ed727b0..04f779821405 100644
--- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
def preprocess_batch(samples) -> dict:
input_ids = torch.stack(samples)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
- return {'input_ids': input_ids, 'attention_mask': attention_mask}
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
def print_rank_0(*args, **kwargs) -> None:
@@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None:
B = 1024**3
M = 1024**2
K = 1024
- outputs = ''
+ outputs = ""
for name, numel in model_dict.items():
- outputs += f'{name}: '
+ outputs += f"{name}: "
if numel >= B:
- outputs += f'{numel / B:.2f} B\n'
+ outputs += f"{numel / B:.2f} B\n"
elif numel >= M:
- outputs += f'{numel / M:.2f} M\n'
+ outputs += f"{numel / M:.2f} M\n"
elif numel >= K:
- outputs += f'{numel / K:.2f} K\n'
+ outputs += f"{numel / K:.2f} K\n"
else:
- outputs += f'{numel}\n'
+ outputs += f"{numel}\n"
print_rank_0(outputs)
def get_gpt_config(model_name: str) -> OPTConfig:
model_map = {
- '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
- '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
- '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
- '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
- '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
- '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
- '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
- '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
- '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
- '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
+ "125m": OPTConfig.from_pretrained("facebook/opt-125m"),
+ "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
+ "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
+ "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
+ "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
+ "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
+ "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
+ "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
+ "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
+ "13b": OPTConfig.from_pretrained("facebook/opt-13b"),
}
try:
return model_map[model_name]
@@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig:
def main(args):
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif args.strategy == 'colossalai_gemini_cpu':
- strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2_cpu':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
- elif args.strategy == 'colossalai_zero1':
- strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero1_cpu':
- strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+ elif args.strategy == "colossalai_gemini_cpu":
+ strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2_cpu":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
+ elif args.strategy == "colossalai_zero1":
+ strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero1_cpu":
+ strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
@@ -103,90 +103,106 @@ def main(args):
if args.use_kernels:
from coati.kernels import convert_to_xformer_model
- actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
- (actor, critic, initial_model, reward_model))
+
+ actor, critic, initial_model, reward_model = map(
+ convert_to_xformer_model, (actor, critic, initial_model, reward_model)
+ )
actor_numel = get_model_numel(actor, strategy)
critic_numel = get_model_numel(critic, strategy)
initial_model_numel = get_model_numel(initial_model, strategy)
reward_model_numel = get_model_numel(reward_model, strategy)
- print_model_numel({
- 'Actor': actor_numel,
- 'Critic': critic_numel,
- 'Initial model': initial_model_numel,
- 'Reward model': reward_model_numel
- })
- performance_evaluator = PerformanceEvaluator(actor_numel,
- critic_numel,
- initial_model_numel,
- reward_model_numel,
- enable_grad_checkpoint=False,
- ignore_episodes=1)
-
- if args.strategy.startswith('colossalai'):
+ print_model_numel(
+ {
+ "Actor": actor_numel,
+ "Critic": critic_numel,
+ "Initial model": initial_model_numel,
+ "Reward model": reward_model_numel,
+ }
+ )
+ performance_evaluator = PerformanceEvaluator(
+ actor_numel,
+ critic_numel,
+ initial_model_numel,
+ reward_model_numel,
+ enable_grad_checkpoint=False,
+ ignore_episodes=1,
+ )
+
+ if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else:
actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)
- tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
- dataloader = DataLoader(random_prompts,
- batch_size=args.experience_batch_size,
- shuffle=True,
- collate_fn=preprocess_batch)
-
- trainer = PPOTrainer(strategy,
- actor,
- critic,
- reward_model,
- initial_model,
- actor_optim,
- critic_optim,
- ptx_coef=0,
- train_batch_size=args.train_batch_size,
- offload_inference_models=args.offload_inference_models,
- max_length=512,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- use_cache=True,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- callbacks=[performance_evaluator])
-
- trainer.fit(prompt_dataloader=dataloader,
- pretrain_dataloader=None,
- num_episodes=args.num_episodes,
- num_update_steps=args.num_update_steps,
- num_collect_steps=args.num_collect_steps)
-
- print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
-
-
-if __name__ == '__main__':
+ dataloader = DataLoader(
+ random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch
+ )
+
+ trainer = PPOTrainer(
+ strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ ptx_coef=0,
+ train_batch_size=args.train_batch_size,
+ offload_inference_models=args.offload_inference_models,
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ use_cache=True,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ callbacks=[performance_evaluator],
+ )
+
+ trainer.fit(
+ prompt_dataloader=dataloader,
+ pretrain_dataloader=None,
+ num_episodes=args.num_episodes,
+ num_update_steps=args.num_update_steps,
+ num_collect_steps=args.num_collect_steps,
+ )
+
+ print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
+
+
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='125m')
- parser.add_argument('--critic_model', default='125m')
- parser.add_argument('--strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
- 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
- ],
- default='ddp')
- parser.add_argument('--num_episodes', type=int, default=3)
- parser.add_argument('--num_collect_steps', type=int, default=8)
- parser.add_argument('--num_update_steps', type=int, default=1)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0)
- parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
- parser.add_argument('--offload_inference_models', action='store_true', default=False)
- parser.add_argument('--use_kernels', action='store_true', default=False)
+ parser.add_argument("--model", default="125m")
+ parser.add_argument("--critic_model", default="125m")
+ parser.add_argument(
+ "--strategy",
+ choices=[
+ "ddp",
+ "colossalai_gemini",
+ "colossalai_gemini_cpu",
+ "colossalai_zero2",
+ "colossalai_zero2_cpu",
+ "colossalai_zero1",
+ "colossalai_zero1_cpu",
+ ],
+ default="ddp",
+ )
+ parser.add_argument("--num_episodes", type=int, default=3)
+ parser.add_argument("--num_collect_steps", type=int, default=8)
+ parser.add_argument("--num_update_steps", type=int, default=1)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0)
+ parser.add_argument("--cuda_mem_frac", type=float, default=1.0)
+ parser.add_argument("--offload_inference_models", action="store_true", default=False)
+ parser.add_argument("--use_kernels", action="store_true", default=False)
args = parser.parse_args()
main(args)
diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/Chat/benchmarks/ray/1mmt_dummy.py
index 7fc990448805..98ace3869450 100644
--- a/applications/Chat/benchmarks/ray/1mmt_dummy.py
+++ b/applications/Chat/benchmarks/ray/1mmt_dummy.py
@@ -22,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -36,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
- 'local_rank': '0',
- 'rank': '0',
- 'world_size': '1',
- 'master_port': maker_port,
- 'master_addr': master_addr
+ "local_rank": "0",
+ "rank": "0",
+ "world_size": "1",
+ "master_port": maker_port,
+ "master_addr": master_addr,
}
# configure tokenizer
@@ -63,21 +66,27 @@ def model_fn():
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
- reward_model = get_reward_model_from_args(args.critic_model,
- config=critic_cfg).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ reward_model = (
+ get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ )
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
+ detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
@@ -97,15 +106,18 @@ def model_fn():
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
- critic = get_critic_from_args(args.critic_model,
- config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+ critic = (
+ get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+ .half()
+ .cuda()
+ )
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
- f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
+ f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
@@ -114,7 +126,8 @@ def trainer_model_fn():
buffer_limit=16,
eval_performance=True,
debug=args.debug,
- ) for i, env_info_trainer in enumerate(env_info_trainers)
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
@@ -122,7 +135,7 @@ def trainer_model_fn():
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
- return {'input_ids': input_ids, 'attention_mask': attn_mask}
+ return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
@@ -138,8 +151,10 @@ def build_dataloader(size):
wait_tasks = []
wait_tasks.append(
- experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
- num_steps=args.experience_steps))
+ experience_holder_ref.workingloop.remote(
+ partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+ )
+ )
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs:
@@ -148,31 +163,30 @@ def build_dataloader(size):
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--num_trainers', type=int, default=1)
- parser.add_argument('--trainer_strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
- 'colossalai_zero2_cpu'
- ],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/Chat/benchmarks/ray/mmmt_dummy.py
index ca1df22070fc..f8860f2979ee 100644
--- a/applications/Chat/benchmarks/ray/mmmt_dummy.py
+++ b/applications/Chat/benchmarks/ray/mmmt_dummy.py
@@ -22,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -36,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
- env_info_makers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_makers),
- 'master_port': maker_port,
- 'master_addr': master_addr
- } for rank in range(args.num_makers)]
+ env_info_makers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_makers),
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_makers)
+ ]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@@ -63,14 +69,20 @@ def model_fn():
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
- reward_model = get_reward_model_from_args(args.critic_model,
- config=critic_cfg).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ reward_model = (
+ get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ )
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
@@ -79,7 +91,7 @@ def model_fn():
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
- f'trainer{x}'
+ f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@@ -103,8 +115,11 @@ def model_fn():
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
- critic = get_critic_from_args(args.critic_model,
- config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+ critic = (
+ get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+ .half()
+ .cuda()
+ )
return actor, critic
# configure Trainer
@@ -130,7 +145,7 @@ def trainer_model_fn():
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
- return {'input_ids': input_ids, 'attention_mask': attn_mask}
+ return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
@@ -147,43 +162,48 @@ def build_dataloader(size):
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(
- experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
- num_steps=args.experience_steps))
+ experience_holder_ref.workingloop.remote(
+ partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+ )
+ )
- total_steps = args.experience_batch_size * args.experience_steps * \
- args.num_makers // (args.num_trainers * args.train_batch_size)
+ total_steps = (
+ args.experience_batch_size
+ * args.experience_steps
+ * args.num_makers
+ // (args.num_trainers * args.train_batch_size)
+ )
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--num_makers', type=int, default=1)
- parser.add_argument('--num_trainers', type=int, default=1)
- parser.add_argument('--trainer_strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
- 'colossalai_zero2_cpu'
- ],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ parser.add_argument("--num_makers", type=int, default=1)
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
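
The reformatted `total_steps` expression above sets how many optimizer steps
each trainer runs; a worked example with this script's defaults (8 experiences
per batch, 4 experience steps, 1 maker, 1 trainer, train batch size 8):

    experience_batch_size, experience_steps = 8, 4
    num_makers, num_trainers, train_batch_size = 1, 1, 8

    total_steps = (
        experience_batch_size * experience_steps * num_makers
        // (num_trainers * train_batch_size)
    )
    assert total_steps == 4  # 32 experiences split into batches of 8
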
diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py
index bd4e5460d11e..599b57609775 100644
--- a/applications/Chat/coati/dataset/__init__.py
+++ b/applications/Chat/coati/dataset/__init__.py
@@ -4,7 +4,10 @@
from .utils import is_rank_0
__all__ = [
- 'RmStaticDataset', 'HhRlhfDataset',
- 'SFTDataset', 'SupervisedDataset',
- 'PromptDataset', 'is_rank_0',
+ "RmStaticDataset",
+ "HhRlhfDataset",
+ "SFTDataset",
+ "SupervisedDataset",
+ "PromptDataset",
+ "is_rank_0",
]
diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py
index 465fa867c7ab..f2180d96b0d3 100644
--- a/applications/Chat/coati/dataset/conversation.py
+++ b/applications/Chat/coati/dataset/conversation.py
@@ -49,7 +49,7 @@ def append_message(self, role, message):
def to_gradio_chatbot(self):
ret = []
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
@@ -57,12 +57,14 @@ def to_gradio_chatbot(self):
return ret
def copy(self):
- return Conversation(system=self.system,
- roles=self.roles,
- messages=[[x, y] for x, y in self.messages],
- offset=self.offset,
- sep_style=self.sep_style,
- sep=self.sep)
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ )
def dict(self):
return {
@@ -70,7 +72,7 @@ def dict(self):
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
- "sep": self.sep
+ "sep": self.sep,
}
diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py
index 2c953fffa513..17120e6064b5 100644
--- a/applications/Chat/coati/dataset/prompt_dataset.py
+++ b/applications/Chat/coati/dataset/prompt_dataset.py
@@ -13,11 +13,13 @@
class PromptDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self,
- data_path: str,
- tokenizer: transformers.PreTrainedTokenizer,
- max_datasets_size: int = None,
- max_length: int = 96):
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ max_datasets_size: int = None,
+ max_length: int = 96,
+ ):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
self.logger = get_dist_logger()
@@ -30,11 +32,9 @@ def __init__(self,
list_data_dict = list_data_dict[:max_datasets_size]
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
- tokens = tokenizer(instructions,
- return_tensors='pt',
- max_length=max_length,
- padding='max_length',
- truncation=True)
+ tokens = tokenizer(
+ instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
+ )
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
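
The condensed tokenizer call returns a dict of stacked tensors, and `unbind`
then splits each tensor back into per-prompt rows. A quick sketch of that
shape transformation (sizes chosen to match the max_length=96 default):

    import torch

    batch = torch.zeros(3, 96, dtype=torch.long)  # stand-in for tokens["input_ids"]
    per_sample = batch.unbind()                   # tuple of 3 tensors, each (96,)
    assert len(per_sample) == 3 and per_sample[0].shape == (96,)
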
diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py
index 3c4ec8b214bb..3afcd7b69238 100644
--- a/applications/Chat/coati/dataset/reward_dataset.py
+++ b/applications/Chat/coati/dataset/reward_dataset.py
@@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.end_token = tokenizer.eos_token \
- if special_token is None else special_token
-
- chosen = [
- data["prompt"] + data["chosen"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen = {
- "input_ids": chosen_token["input_ids"],
- "attention_mask": chosen_token["attention_mask"]
- }
-
- reject = [
- data["prompt"] + data["rejected"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject = {
- "input_ids": reject_token["input_ids"],
- "attention_mask": reject_token["attention_mask"]
- }
+ self.end_token = tokenizer.eos_token if special_token is None else special_token
+
+ chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+
+ reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
- self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
+ return (
+ self.chosen["input_ids"][idx],
+ self.chosen["attention_mask"][idx],
+ self.reject["input_ids"][idx],
+ self.reject["attention_mask"][idx],
+ )
# Anthropic/hh-rlhf
@@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.end_token = tokenizer.eos_token \
- if special_token is None else special_token
-
- chosen = [
- data["chosen"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen = {
- "input_ids": chosen_token["input_ids"],
- "attention_mask": chosen_token["attention_mask"]
- }
-
- reject = [
- data["rejected"] + self.end_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject = {
- "input_ids": reject_token["input_ids"],
- "attention_mask": reject_token["attention_mask"]
- }
+ self.end_token = tokenizer.eos_token if special_token is None else special_token
+
+ chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+
+ reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
- self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
+ return (
+ self.chosen["input_ids"][idx],
+ self.chosen["attention_mask"][idx],
+ self.reject["input_ids"][idx],
+ self.reject["attention_mask"][idx],
+ )
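Note: with __getitem__ returning an explicit 4-tuple, the default DataLoader collate stacks items into four batched tensors. A minimal sketch, assuming `tokenizer` is any Hugging Face tokenizer with a pad token; the toy sample is illustrative only:

    from torch.utils.data import DataLoader

    data = [{"prompt": "Q: hi\n", "chosen": "Hello!", "rejected": "Go away."}]
    rm_dataset = RmStaticDataset(data, tokenizer, max_length=32)
    loader = DataLoader(rm_dataset, batch_size=1)
    chosen_ids, chosen_mask, reject_ids, reject_mask = next(iter(loader))
    # each tensor has shape (batch_size, max_length)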
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 2959d3fac81c..d6be09ca5cc9 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -16,10 +16,11 @@
from typing import Dict, Sequence, Tuple
import torch
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
@@ -28,32 +29,33 @@
IGNORE_INDEX = -100
PROMPT_DICT = {
- "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
- "prompt_no_input": ("Below is an instruction that describes a task. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Response:"),
+ "prompt_input": (
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+ ),
+ "prompt_no_input": (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Response:"
+ ),
}
-def _preprocess(sources: Sequence[str],
- targets: Sequence[str],
- tokenizer: PreTrainedTokenizer,
- max_length: int,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def _preprocess(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)]
- sequences_token = tokenizer(sequences,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- sources_token = tokenizer(sources,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
+ sequences_token = tokenizer(
+ sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ sources_token = tokenizer(
+ sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
@@ -64,23 +66,24 @@ def _preprocess(sources: Sequence[str],
labels[i][:source_len] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
- labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
+ labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
-def _preprocess_chatglm(sources: Sequence[str],
- targets: Sequence[str],
- tokenizer: PreTrainedTokenizer,
- max_length: int,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def _preprocess_chatglm(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preprocess the data by tokenizing.
None for attention mask; ChatGLM computes the attention mask from the input ids
"""
-
+
labels = []
input_ids = []
for source, target in zip(sources, targets):
@@ -90,16 +93,16 @@ def _preprocess_chatglm(sources: Sequence[str],
# truncate
sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
truncate_length = max(0, len(input_id) - max_length)
- input_id = input_id[truncate_length: ]
+ input_id = input_id[truncate_length:]
if truncate_length == len(source_id) + 1:
- input_id = sp_token_list + input_id[1: ]
+ input_id = sp_token_list + input_id[1:]
elif truncate_length > len(source_id) + 1:
- input_id = sp_token_list + input_id[2: ]
-
+ input_id = sp_token_list + input_id[2:]
+
context_length = input_id.index(tokenizer.bos_token_id)
mask_position = context_length - 1
- label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
-
+ label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]
+
pad_len = max_length - len(input_id)
input_id = input_id + [tokenizer.pad_token_id] * pad_len
input_ids.append(input_id)
@@ -117,25 +120,18 @@ class SFTDataset(Dataset):
max_length: max length of input
"""
- def __init__(self,
- dataset: Dict,
- tokenizer: PreTrainedTokenizer,
- max_length: int = 512
- ) -> None:
+ def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
super().__init__()
self.input_ids = []
sources = [data["prompt"] for data in dataset]
- targets = [
- data["completion"] + tokenizer.eos_token
- for data in tqdm(dataset, disable=not is_rank_0())
- ]
+ targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
if isinstance(tokenizer, ChatGLMTokenizer):
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+ sources, targets, tokenizer, max_length
+ )
else:
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
@@ -143,22 +139,17 @@ def __len__(self):
def __getitem__(self, idx):
if self.attention_mask is not None:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
else:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self,
- data_path: str,
- tokenizer: PreTrainedTokenizer,
- max_datasets_size: int = None,
- max_length: int = 512):
+ def __init__(
+ self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
+ ):
super().__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
@@ -174,18 +165,15 @@ def __init__(self,
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
for example in list_data_dict
]
- targets = [
- example['output'] + tokenizer.eos_token
- for example in list_data_dict
- ]
+ targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]
logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer):
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+ sources, targets, tokenizer, max_length
+ )
else:
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
@@ -193,9 +181,6 @@ def __len__(self):
def __getitem__(self, idx):
if self.attention_mask is not None:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
else:
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx])
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
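Note: the masking in `_preprocess` above hides prompt tokens from the loss. A tiny worked example of the right-padding branch, with illustrative token ids:

    import torch

    IGNORE_INDEX = -100
    # prompt = 3 tokens, completion = 2 tokens, max_length = 7, right padding
    input_ids = torch.tensor([[11, 12, 13, 21, 22, 0, 0]])
    labels = input_ids.clone()
    source_len = 3  # non-pad length of the prompt alone
    labels[0][:source_len] = IGNORE_INDEX
    # labels -> [[-100, -100, -100, 21, 22, 0, 0]]:
    # only completion (and pad) positions keep their ids as training targets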
diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py
index c0188dc4a471..f2a48d0a3b20 100644
--- a/applications/Chat/coati/experience_buffer/__init__.py
+++ b/applications/Chat/coati/experience_buffer/__init__.py
@@ -1,4 +1,4 @@
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer
-__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
+__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]
diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py
index 9ccdc935d506..7047785308f3 100644
--- a/applications/Chat/coati/experience_buffer/base.py
+++ b/applications/Chat/coati/experience_buffer/base.py
@@ -7,9 +7,9 @@
class ExperienceBuffer(ABC):
"""Experience buffer base class. It stores experience.
- Args:
- sample_batch_size (int): Batch size when sampling.
- limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py
index bd5213b38993..acc0fbe88ab4 100644
--- a/applications/Chat/coati/experience_buffer/naive.py
+++ b/applications/Chat/coati/experience_buffer/naive.py
@@ -11,23 +11,23 @@
class NaiveExperienceBuffer(ExperienceBuffer):
"""Naive experience buffer class. It stores experience.
- Args:
- sample_batch_size (int): Batch size when sampling.
- limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
- cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+ cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
"""
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
super().__init__(sample_batch_size, limit)
self.cpu_offload = cpu_offload
- self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
+ self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
@torch.no_grad()
def append(self, experience: Experience) -> None:
if self.cpu_offload:
- experience.to_device(torch.device('cpu'))
+ experience.to_device(torch.device("cpu"))
items = split_experience_batch(experience)
self.items.extend(items)
if self.limit > 0:
diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py
index c2a34212e2f4..baedbebd184f 100644
--- a/applications/Chat/coati/experience_buffer/utils.py
+++ b/applications/Chat/coati/experience_buffer/utils.py
@@ -21,6 +21,7 @@ class BufferItem:
"A" is the number of actions.
"""
+
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
@@ -33,8 +34,7 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
- keys = ('sequences', 'action_log_probs', 'values',
- 'reward', 'advantages', 'attention_mask', 'action_mask')
+ keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
@@ -49,22 +49,21 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
return items
-def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
- assert side in ('left', 'right')
+def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
+ assert side in ("left", "right")
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - seq.size(0)
- padding = (pad_len, 0) if side == 'left' else (0, pad_len)
+ padding = (pad_len, 0) if side == "left" else (0, pad_len)
padded_sequences.append(F.pad(seq, padding))
return torch.stack(padded_sequences, dim=0)
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
- to_pad_keys = set(('action_log_probs', 'action_mask'))
- keys = ('sequences', 'action_log_probs', 'values',
- 'reward', 'advantages', 'attention_mask', 'action_mask')
+ to_pad_keys = set(("action_log_probs", "action_mask"))
+ keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
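Note: `_zero_pad_sequences` exists so that per-sample tensors of different lengths can be stacked into one batch. A self-contained sketch of the left-padding case:

    import torch
    import torch.nn.functional as F

    def zero_pad_left(sequences):
        max_len = max(seq.size(0) for seq in sequences)
        # F.pad on a 1-D tensor takes (pad_left, pad_right)
        return torch.stack([F.pad(seq, (max_len - seq.size(0), 0)) for seq in sequences], dim=0)

    print(zero_pad_left([torch.tensor([1, 2]), torch.tensor([3])]))
    # tensor([[1, 2],
    #         [0, 3]])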
diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py
index 39ca7576b227..06452292e77c 100644
--- a/applications/Chat/coati/experience_maker/__init__.py
+++ b/applications/Chat/coati/experience_maker/__init__.py
@@ -1,4 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker
-__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
+__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]
diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py
index b4646f282f0c..727f0a4a52e8 100644
--- a/applications/Chat/coati/experience_maker/base.py
+++ b/applications/Chat/coati/experience_maker/base.py
@@ -24,6 +24,7 @@ class Experience:
"A" is the number of actions.
"""
+
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
@@ -58,13 +59,9 @@ def pin_memory(self):
class ExperienceMaker(ABC):
-
- def __init__(self,
- actor: Actor,
- critic: nn.Module,
- reward_model: nn.Module,
- initial_model: Actor,
- kl_coef: float = 0.1) -> None:
+ def __init__(
+ self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
+ ) -> None:
super().__init__()
self.actor = actor
self.critic = critic
diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py
index 496f8ab445fc..30dfd8e0b9bc 100644
--- a/applications/Chat/coati/experience_maker/naive.py
+++ b/applications/Chat/coati/experience_maker/naive.py
@@ -23,22 +23,21 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie
# calculate auxiliary tensors
attention_mask = None
- pad_token_id = generate_kwargs.get('pad_token_id', None)
+ pad_token_id = generate_kwargs.get("pad_token_id", None)
if pad_token_id is not None:
- attention_mask = sequences.not_equal(pad_token_id)\
- .to(dtype=torch.long, device=sequences.device)
+ attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
- eos_token_id = generate_kwargs.get('eos_token_id', None)
+ eos_token_id = generate_kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, so mask only the action part
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
- action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
- action_mask = action_mask[:, -(sequences.size(1) - input_len):]
+ action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask)
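Note: the action mask above keeps response tokens up to and including the first EOS. The core cumsum trick in isolation, with illustrative token ids:

    import torch

    eos_token_id = 2
    response = torch.tensor([[5, 6, 2, 7]])  # tokens generated after the prompt
    mask = (response == eos_token_id).cumsum(dim=-1) == 0
    # mask -> [[True, True, False, False]]: the EOS and everything after it are
    # dropped; the right-shift done by F.pad(..., (1 + input_len, -1), value=True)
    # then re-includes the EOS itself.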
diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py
index 230eedf7ecba..96d40c7c4709 100644
--- a/applications/Chat/coati/kernels/__init__.py
+++ b/applications/Chat/coati/kernels/__init__.py
@@ -1,6 +1,6 @@
from .wrapper import convert_to_xformer_model, recover_from_xformer_model
__all__ = [
- 'convert_to_xformer_model',
- 'recover_from_xformer_model',
+ "convert_to_xformer_model",
+ "recover_from_xformer_model",
]
diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py
index e99f9c2247d1..d1eb139187f3 100644
--- a/applications/Chat/coati/kernels/opt_attn.py
+++ b/applications/Chat/coati/kernels/opt_attn.py
@@ -21,11 +21,12 @@ def forward(
output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
if not self.training:
- return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask,
- output_attentions)
+ return super().forward(
+ hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
+ )
"""Input shape: Batch x Time x Channel"""
- assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
- assert not output_attentions, 'Xformers attention does not support output_attentions'
+ assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
+ assert not output_attentions, "Xformers attention does not support output_attentions"
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
@@ -69,12 +70,14 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- attn_output = xops.memory_efficient_attention(query_states,
- key_states,
- value_states,
- attn_bias=xops.LowerTriangularMask(),
- p=self.dropout if self.training else 0.0,
- scale=self.scaling)
+ attn_output = xops.memory_efficient_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_bias=xops.LowerTriangularMask(),
+ p=self.dropout if self.training else 0.0,
+ scale=self.scaling,
+ )
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py
index 0a296a863756..ad4a525b4af2 100644
--- a/applications/Chat/coati/models/__init__.py
+++ b/applications/Chat/coati/models/__init__.py
@@ -3,6 +3,13 @@
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
__all__ = [
- 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
- 'LoRAModule', 'convert_to_lora_module'
+ "Actor",
+ "Critic",
+ "RewardModel",
+ "PolicyLoss",
+ "ValueLoss",
+ "LogSigLoss",
+ "LogExpLoss",
+ "LoRAModule",
+ "convert_to_lora_module",
]
diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py
index c5f748a0c85a..5c9905bb2224 100644
--- a/applications/Chat/coati/models/base/__init__.py
+++ b/applications/Chat/coati/models/base/__init__.py
@@ -9,7 +9,7 @@
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
"""Get the base model of our wrapper classes.
- For Actor, Critic and RewardModel, return ``model.model``,
+ For Actor, Critic and RewardModel, return ``model.model``,
it's usually a ``transformers.PreTrainedModel``.
Args:
@@ -18,9 +18,10 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
Returns:
nn.Module: the base model
"""
- assert isinstance(model, (Actor, Critic, RewardModel)), \
- f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
+ assert isinstance(
+ model, (Actor, Critic, RewardModel)
+ ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
return model.model
-__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
+__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]
diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py
index 6842f81d9b87..979f9318be50 100644
--- a/applications/Chat/coati/models/base/actor.py
+++ b/applications/Chat/coati/models/base/actor.py
@@ -16,18 +16,17 @@ class Actor(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
def forward(
- self,
- input_ids: torch.LongTensor,
- attention_mask: Optional[torch.Tensor] = None,
- **model_kwargs, # HACK: `generate` method may pass more kwargs
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **model_kwargs, # HACK: `generate` method may pass more kwargs
) -> torch.Tensor:
- """Returns model output.
- """
+ """Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
return output
diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py
index e68a743a7762..54ab7fa47d48 100644
--- a/applications/Chat/coati/models/base/critic.py
+++ b/applications/Chat/coati/models/base/critic.py
@@ -23,22 +23,23 @@ def __init__(
model: nn.Module,
value_head: nn.Module,
lora_rank: int = 0,
- lora_train_bias: str = 'none',
+ lora_train_bias: str = "none",
use_action_mask: bool = False,
) -> None:
-
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora()
- def forward(self,
- sequences: torch.LongTensor,
- action_mask: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(
+ self,
+ sequences: torch.LongTensor,
+ action_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
- last_hidden_states = outputs['last_hidden_state']
+ last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states).squeeze(-1)
diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py
index ce8c0a1d3568..1a70c6cc12bb 100644
--- a/applications/Chat/coati/models/base/reward_model.py
+++ b/applications/Chat/coati/models/base/reward_model.py
@@ -17,11 +17,13 @@ class RewardModel(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- model: nn.Module,
- value_head: Optional[nn.Module] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ model: nn.Module,
+ value_head: Optional[nn.Module] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
@@ -35,7 +37,7 @@ def __init__(self,
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
- last_hidden_states = outputs['last_hidden_state']
+ last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states)[:, :-1]
- value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
+ value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
return value
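Note: a quick shape walkthrough of the reward head above, assuming the value head maps hidden size H to 1: `last_hidden_states` is (B, L, H), `self.value_head(...)` gives (B, L, 1), dropping the final position with `[:, :-1]` gives (B, L-1, 1), `mean(dim=1)` gives (B, 1), and `squeeze(1)` yields the per-sample scalar reward of shape (B,).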
diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py
index d0e7f7b1ef94..7af199a67d3b 100644
--- a/applications/Chat/coati/models/bloom/__init__.py
+++ b/applications/Chat/coati/models/bloom/__init__.py
@@ -2,4 +2,4 @@
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
-__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
+__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"]
diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py
index d7577f096493..73855a2245e7 100644
--- a/applications/Chat/coati/models/bloom/bloom_actor.py
+++ b/applications/Chat/coati/models/bloom/bloom_actor.py
@@ -1,7 +1,6 @@
from typing import Optional
-import torch
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomForCausalLM
from ..base import Actor
@@ -18,12 +17,14 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py
index a3716ca94138..b2d838f7ffc5 100644
--- a/applications/Chat/coati/models/bloom/bloom_critic.py
+++ b/applications/Chat/coati/models/bloom/bloom_critic.py
@@ -1,8 +1,7 @@
from typing import Optional
-import torch
import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel
from ..base import Critic
@@ -18,12 +17,14 @@ class BLOOMCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py
index e6ca9b1d4851..c09457ddc8c7 100644
--- a/applications/Chat/coati/models/bloom/bloom_rm.py
+++ b/applications/Chat/coati/models/bloom/bloom_rm.py
@@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel
from ..base import RewardModel
@@ -17,11 +17,13 @@ class BLOOMRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py
index 373f19553fdc..5956f5a8e91b 100644
--- a/applications/Chat/coati/models/chatglm/__init__.py
+++ b/applications/Chat/coati/models/chatglm/__init__.py
@@ -1,3 +1,3 @@
from .chatglm_actor import ChatGLMActor
-__all__ = ['ChatGLMActor']
\ No newline at end of file
+__all__ = ["ChatGLMActor"]
diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py
index c35d994e9319..00a61561ee47 100644
--- a/applications/Chat/coati/models/chatglm/chatglm_actor.py
+++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py
@@ -1,11 +1,9 @@
from typing import Optional
-import torch
+from ..base import Actor
from .configuration_chatglm import ChatGLMConfig
from .modeling_chatglm import ChatGLMForConditionalGeneration
-from ..base import Actor
-
class ChatGLMActor(Actor):
"""
@@ -19,10 +17,9 @@ class ChatGLMActor(Actor):
does not support LoRA for now.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[ChatGLMConfig] = None,
- checkpoint: bool = False) -> None:
+ def __init__(
+ self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False
+ ) -> None:
if pretrained is not None:
model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
elif config is not None:
@@ -31,4 +28,4 @@ def __init__(self,
model = ChatGLMForConditionalGeneration(ChatGLMConfig())
if checkpoint:
model.gradient_checkpointing_enable()
- super().__init__(model, lora_rank=0, lora_train_bias='none')
+ super().__init__(model, lora_rank=0, lora_train_bias="none")
diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
index f7717f7e68b6..221ef044b470 100644
--- a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
+++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
@@ -2,15 +2,14 @@
This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
"""
"""Tokenization classes for ChatGLM."""
-from typing import List, Optional, Union
import os
+from typing import Dict, List, Optional, Union
-from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.utils import logging, PaddingStrategy
-from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
-from typing import Dict
-import sentencepiece as spm
import numpy as np
+import sentencepiece as spm
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
+from transformers.utils import PaddingStrategy, logging
logger = logging.get_logger(__name__)
@@ -52,11 +51,11 @@ def __len__(self):
class SPTokenizer:
def __init__(
- self,
- vocab_file,
- num_image_tokens=20000,
- max_blank_length=80,
- byte_fallback=True,
+ self,
+ vocab_file,
+ num_image_tokens=20000,
+ max_blank_length=80,
+ byte_fallback=True,
):
assert vocab_file is not None
self.vocab_file = vocab_file
@@ -100,9 +99,7 @@ def _preprocess(self, text: str, linebreak=True, whitespaces=True):
text = self._encode_whitespaces(text, max_len=self.max_blank_length)
return text
- def encode(
- self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
- ) -> List[int]:
+ def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (\n) in text.
@@ -136,9 +133,7 @@ def decode_tokens(self, tokens: List[str]) -> str:
text = self.postprocess(text)
return text
- def tokenize(
- self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
- ) -> List[str]:
+ def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (\n) in text.
@@ -181,20 +176,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(
- self,
- vocab_file,
- do_lower_case=False,
- remove_space=False,
- bos_token='<sop>',
- eos_token='<eop>',
- end_token='</s>',
- mask_token='[MASK]',
- gmask_token='[gMASK]',
- padding_side="left",
- pad_token="<pad>",
- unk_token="<unk>",
- num_image_tokens=20000,
- **kwargs
+ self,
+ vocab_file,
+ do_lower_case=False,
+ remove_space=False,
+ bos_token="<sop>",
+ eos_token="<eop>",
+ end_token="</s>",
+ mask_token="[MASK]",
+ gmask_token="[gMASK]",
+ padding_side="left",
+ pad_token="<pad>",
+ unk_token="<unk>",
+ num_image_tokens=20000,
+ **kwargs,
) -> None:
super().__init__(
do_lower_case=do_lower_case,
@@ -208,7 +203,7 @@ def __init__(
pad_token=pad_token,
unk_token=unk_token,
num_image_tokens=num_image_tokens,
- **kwargs
+ **kwargs,
)
self.do_lower_case = do_lower_case
@@ -243,11 +238,11 @@ def end_token_id(self) -> Optional[int]:
@property
def vocab_size(self):
- """ Returns vocab size """
+ """Returns vocab size"""
return self.sp_tokenizer.num_tokens
def get_vocab(self):
- """ Returns vocab as a dict """
+ """Returns vocab as a dict"""
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
@@ -264,7 +259,7 @@ def preprocess_text(self, inputs):
return outputs
def _tokenize(self, text, **kwargs):
- """ Returns a tokenized string. """
+ """Returns a tokenized string."""
text = self.preprocess_text(text)
seq = self.sp_tokenizer.tokenize(text)
@@ -274,11 +269,7 @@ def _tokenize(self, text, **kwargs):
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.sp_tokenizer.decode_tokens(tokens)
- def _decode(
- self,
- token_ids: Union[int, List[int]],
- **kwargs
- ) -> str:
+ def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if len(token_ids) == 0:
@@ -288,7 +279,7 @@ def _decode(
return super()._decode(token_ids, **kwargs)
def _convert_token_to_id(self, token):
- """ Converts a token (str) in an id using the vocab. """
+ """Converts a token (str) in an id using the vocab."""
return self.sp_tokenizer[token]
def _convert_id_to_token(self, index):
@@ -309,13 +300,11 @@ def save_vocabulary(self, save_directory, filename_prefix=None):
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
- vocab_file = os.path.join(
- save_directory, self.vocab_files_names["vocab_file"]
- )
+ vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
else:
vocab_file = save_directory
- with open(self.vocab_file, 'rb') as fin:
+ with open(self.vocab_file, "rb") as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
@@ -324,7 +313,7 @@ def save_vocabulary(self, save_directory, filename_prefix=None):
return (vocab_file,)
def build_inputs_with_special_tokens(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
@@ -343,19 +332,19 @@ def build_inputs_with_special_tokens(
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
gmask_id = self.sp_tokenizer[self.gmask_token]
- eos_id = self.sp_tokenizer[self.eos_token]
token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1
return token_ids_0
def _pad(
- self,
- encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
- max_length: Optional[int] = None,
- padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
- pad_to_multiple_of: Optional[int] = None,
- return_attention_mask: Optional[bool] = None,
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
@@ -421,17 +410,23 @@ def _pad(
mask_position = required_input.index(mask_token)
position_ids[context_length:] = mask_position
block_position_ids = np.concatenate(
- [np.zeros(context_length, dtype=np.int64),
- np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+ [
+ np.zeros(context_length, dtype=np.int64),
+ np.arange(1, seq_length - context_length + 1, dtype=np.int64),
+ ]
+ )
encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
- encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
- pad_width=[(0, 0), (difference, 0), (difference, 0)],
- mode='constant', constant_values=True)
+ encoded_inputs["attention_mask"] = np.pad(
+ encoded_inputs["attention_mask"],
+ pad_width=[(0, 0), (difference, 0), (difference, 0)],
+ mode="constant",
+ constant_values=True,
+ )
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
@@ -439,8 +434,9 @@ def _pad(
if "special_tokens_mask" in encoded_inputs:
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
if "position_ids" in encoded_inputs:
- encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
- pad_width=[(0, 0), (difference, 0)])
+ encoded_inputs["position_ids"] = np.pad(
+ encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)]
+ )
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
- return encoded_inputs
\ No newline at end of file
+ return encoded_inputs
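Note: ChatGLM pads on the left and uses a 3-D boolean attention mask, so `_pad` has to grow the mask along both sequence axes at once. A toy illustration of the np.pad call, shapes only:

    import numpy as np

    attention_mask = np.zeros((1, 4, 4), dtype=bool)  # (batch, seq, seq) before padding
    difference = 2  # amount of left padding to add
    padded = np.pad(
        attention_mask,
        pad_width=[(0, 0), (difference, 0), (difference, 0)],
        mode="constant",
        constant_values=True,
    )
    print(padded.shape)  # (1, 6, 6); True marks masked-out positions, matching
    # attention_scores.masked_fill_(attention_mask, -10000.0) in modeling_chatglm.py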
diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
index d0e3f6cc63d7..a6d2ccd18715 100644
--- a/applications/Chat/coati/models/chatglm/configuration_chatglm.py
+++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
@@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
- ```
-"""
+ ```"""
model_type = "chatglm"
def __init__(
- self,
- vocab_size=130528,
- hidden_size=4096,
- num_layers=28,
- num_attention_heads=32,
- layernorm_epsilon=1e-5,
- use_cache=True,
- bos_token_id=130004,
- eos_token_id=130005,
- mask_token_id=130000,
- gmask_token_id=130001,
- pad_token_id=3,
- max_sequence_length=2048,
- inner_hidden_size=16384,
- position_encoding_2d=True,
- quantization_bit=0,
- pre_seq_len=None,
- prefix_projection=False,
- **kwargs
+ self,
+ vocab_size=130528,
+ hidden_size=4096,
+ num_layers=28,
+ num_attention_heads=32,
+ layernorm_epsilon=1e-5,
+ use_cache=True,
+ bos_token_id=130004,
+ eos_token_id=130005,
+ mask_token_id=130000,
+ gmask_token_id=130001,
+ pad_token_id=3,
+ max_sequence_length=2048,
+ inner_hidden_size=16384,
+ position_encoding_2d=True,
+ quantization_bit=0,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs,
):
self.num_layers = num_layers
self.vocab_size = vocab_size
@@ -99,9 +98,4 @@ def __init__(
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- **kwargs
- )
\ No newline at end of file
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
index 77e7d0d8ea09..d1d15c68ffd8 100644
--- a/applications/Chat/coati/models/chatglm/modeling_chatglm.py
+++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
@@ -4,41 +4,40 @@
""" PyTorch ChatGLM model. """
-import math
import copy
+import math
import os
-import warnings
import re
import sys
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
import torch.nn.functional as F
+import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
-
-from transformers.utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
-)
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
- CausalLMOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels
-if sys.platform != 'darwin':
+if sys.platform != "darwin":
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
@@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using the pretrained model
if any(
- n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
- for n in name
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
):
logger.info(f"Skipping {'/'.join(name)}")
continue
@@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
array = np.transpose(array)
try:
assert (
- pointer.shape == array.shape
+ pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
@@ -153,7 +152,7 @@ def __init__(self, config):
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.hidden_size),
torch.nn.Tanh(),
- torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+ torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@@ -170,8 +169,7 @@ def forward(self, prefix: torch.Tensor):
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
- return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
- (1.0 + 0.044715 * x * x)))
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def gelu(x):
@@ -181,21 +179,22 @@ def gelu(x):
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__()
- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq.half()
self.learnable = learnable
if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None
else:
- self.register_buffer('inv_freq', inv_freq)
+ self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
- error_msgs):
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
pass
def forward(self, x, seq_dim=1, seq_len=None):
@@ -204,7 +203,7 @@ def forward(self, x, seq_dim=1, seq_len=None):
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
- freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
@@ -230,30 +229,31 @@ def _apply(self, fn):
def rotate_half(x):
- x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
- cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
- F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
+ position_id, sin.squeeze(1)
+ ).unsqueeze(2)
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
return q, k
def attention_fn(
- self,
- query_layer,
- key_layer,
- value_layer,
- attention_mask,
- hidden_size_per_partition,
- layer_id,
- layer_past=None,
- scaling_attention_score=True,
- use_cache=False,
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ hidden_size_per_partition,
+ layer_id,
+ layer_past=None,
+ scaling_attention_score=True,
+ use_cache=False,
):
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
@@ -285,7 +285,9 @@ def attention_fn(
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
matmul_result = torch.zeros(
- 1, 1, 1,
+ 1,
+ 1,
+ 1,
dtype=query_layer.dtype,
device=query_layer.device,
)
@@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs):
class SelfAttention(torch.nn.Module):
- def __init__(self, hidden_size, num_attention_heads,
- layer_id, hidden_size_per_attention_head=None, bias=True,
- params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+ def __init__(
+ self,
+ hidden_size,
+ num_attention_heads,
+ layer_id,
+ hidden_size_per_attention_head=None,
+ bias=True,
+ params_dtype=torch.float,
+ position_encoding_2d=True,
+ empty_init=True,
+ ):
if empty_init:
init_method = skip_init
else:
@@ -410,8 +420,7 @@ def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
- def split_tensor_along_last_dim(self, tensor, num_partitions,
- contiguous_split_chunks=False):
+ def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
@@ -431,14 +440,14 @@ def split_tensor_along_last_dim(self, tensor, num_partitions,
return tensor_list
def forward(
- self,
- hidden_states: torch.Tensor,
- position_ids,
- attention_mask: torch.Tensor,
- layer_id,
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- use_cache: bool = False,
- output_attentions: bool = False,
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
@@ -462,8 +471,10 @@ def forward(
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
- position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
- position_ids[:, 1, :].transpose(0, 1).contiguous()
+ position_ids, block_position_ids = (
+ position_ids[:, 0, :].transpose(0, 1).contiguous(),
+ position_ids[:, 1, :].transpose(0, 1).contiguous(),
+ )
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
@@ -484,7 +495,7 @@ def forward(
hidden_size_per_partition=self.hidden_size_per_partition,
layer_id=layer_id,
layer_past=layer_past,
- use_cache=use_cache
+ use_cache=use_cache,
)
output = self.dense(context_layer)
@@ -509,8 +520,16 @@ def forward(self, x):
class GLU(torch.nn.Module):
- def __init__(self, hidden_size, inner_hidden_size=None,
- layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
+ def __init__(
+ self,
+ hidden_size,
+ inner_hidden_size=None,
+ layer_id=None,
+ bias=True,
+ activation_func=gelu,
+ params_dtype=torch.float,
+ empty_init=True,
+ ):
super(GLU, self).__init__()
if empty_init:
init_method = skip_init
@@ -557,19 +576,19 @@ def forward(self, hidden_states):
class GLMBlock(torch.nn.Module):
def __init__(
- self,
- hidden_size,
- num_attention_heads,
- layernorm_epsilon,
- layer_id,
- inner_hidden_size=None,
- hidden_size_per_attention_head=None,
- layernorm=LayerNorm,
- use_bias=True,
- params_dtype=torch.float,
- num_layers=28,
- position_encoding_2d=True,
- empty_init=True
+ self,
+ hidden_size,
+ num_attention_heads,
+ layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=None,
+ hidden_size_per_attention_head=None,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=torch.float,
+ num_layers=28,
+ position_encoding_2d=True,
+ empty_init=True,
):
super(GLMBlock, self).__init__()
# Set output layer initialization if not provided.
@@ -590,7 +609,7 @@ def __init__(
bias=use_bias,
params_dtype=params_dtype,
position_encoding_2d=self.position_encoding_2d,
- empty_init=empty_init
+ empty_init=empty_init,
)
# Layernorm on the input data.
@@ -605,18 +624,18 @@ def __init__(
bias=use_bias,
layer_id=layer_id,
params_dtype=params_dtype,
- empty_init=empty_init
+ empty_init=empty_init,
)
def forward(
- self,
- hidden_states: torch.Tensor,
- position_ids,
- attention_mask: torch.Tensor,
- layer_id,
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- use_cache: bool = False,
- output_attentions: bool = False,
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
@@ -635,7 +654,7 @@ def forward(
layer_id=layer_id,
layer_past=layer_past,
use_cache=use_cache,
- output_attentions=output_attentions
+ output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
@@ -702,10 +721,15 @@ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
for i, context_length in enumerate(context_lengths):
position_ids[i, context_length:] = mask_positions[i]
- block_position_ids = [torch.cat((
- torch.zeros(context_length, dtype=torch.long, device=device),
- torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
- )) for context_length in context_lengths]
+ block_position_ids = [
+ torch.cat(
+ (
+ torch.zeros(context_length, dtype=torch.long, device=device),
+ torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
+ )
+ )
+ for context_length in context_lengths
+ ]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
else:
@@ -823,9 +847,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True):
self.prefix_projection = config.prefix_projection
self.word_embeddings = init_method(
- torch.nn.Embedding,
- num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
- dtype=self.params_dtype
+ torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
)
self.gradient_checkpointing = False
@@ -841,12 +863,10 @@ def get_layer(layer_id):
use_bias=True,
params_dtype=self.params_dtype,
position_encoding_2d=self.position_encoding_2d,
- empty_init=empty_init
+ empty_init=empty_init,
)
- self.layers = torch.nn.ModuleList(
- [get_layer(layer_id) for layer_id in range(self.num_layers)]
- )
+ self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
@@ -876,7 +896,7 @@ def get_prompt(self, batch_size, device, dtype=torch.half):
self.pre_seq_len,
self.num_layers * 2,
self.num_attention_heads,
- self.hidden_size // self.num_attention_heads
+ self.hidden_size // self.num_attention_heads,
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
@@ -891,18 +911,17 @@ def get_prompt(self, batch_size, device, dtype=torch.half):
config_class=_CONFIG_FOR_DOC,
)
def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
- inputs_embeds: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -931,17 +950,14 @@ def forward(
if past_key_values is None:
if self.pre_seq_len is not None:
- past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
- dtype=inputs_embeds.dtype)
+ past_key_values = self.get_prompt(
+ batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
+ )
else:
past_key_values = tuple([None] * len(self.layers))
if attention_mask is None:
- attention_mask = self.get_masks(
- input_ids,
- device=input_ids.device
- )
-
+ attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@@ -955,15 +971,13 @@ def forward(
use_gmasks.append(use_gmask)
position_ids = self.get_position_ids(
- input_ids,
- mask_positions=mask_positions,
- device=input_ids.device,
- use_gmasks=use_gmasks
+ input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
)
if self.pre_seq_len is not None and attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
- attention_mask.device)
+ attention_mask.device
+ )
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
@@ -980,7 +994,6 @@ def forward(
attention_mask = attention_mask.to(hidden_states.device)
for i, layer in enumerate(self.layers):
-
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_past = past_key_values[i]
@@ -994,7 +1007,7 @@ def forward(
torch.tensor(i),
layer_past,
use_cache,
- output_attentions
+ output_attentions,
)
else:
layer_ret = layer(
@@ -1004,7 +1017,7 @@ def forward(
layer_id=torch.tensor(i),
layer_past=layer_past,
use_cache=use_cache,
- output_attentions=output_attentions
+ output_attentions=output_attentions,
)
hidden_states = layer_ret[0]
@@ -1049,13 +1062,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True):
self.transformer = ChatGLMModel(config, empty_init=empty_init)
- self.lm_head = init_method(
- nn.Linear,
- config.hidden_size,
- config.vocab_size,
- bias=False,
- dtype=torch.half
- )
+ self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)
self.config = config
@@ -1087,32 +1094,29 @@ def _update_model_kwargs_for_generation(
attention_mask = model_kwargs["attention_mask"]
if attention_mask is not None and attention_mask.dtype == torch.bool:
attention_mask = torch.cat(
- [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+ [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
+ )
new_attention_mask = attention_mask[:, :, -1:].clone()
new_attention_mask[..., -1] = False
- model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, new_attention_mask], dim=2
- )
+ model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)
# update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id[:, 1, :] += 1
- model_kwargs["position_ids"] = torch.cat(
- [position_ids, new_position_id], dim=-1
- )
+ model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
return model_kwargs
def prepare_inputs_for_generation(
- self,
- input_ids: torch.LongTensor,
- past: Optional[torch.Tensor] = None,
- past_key_values: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- **kwargs
+ self,
+ input_ids: torch.LongTensor,
+ past: Optional[torch.Tensor] = None,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs,
) -> dict:
batch_size, seq_length = input_ids.shape
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@@ -1137,11 +1141,17 @@ def prepare_inputs_for_generation(
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
if self.position_encoding_2d:
position_ids = torch.tensor(
- [[mask_position, seq_length - context_length] for mask_position, context_length in
- zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+ [
+ [mask_position, seq_length - context_length]
+ for mask_position, context_length in zip(mask_positions, context_lengths)
+ ],
+ dtype=torch.long,
+ device=input_ids.device,
+ ).unsqueeze(-1)
else:
- position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
- device=input_ids.device).unsqueeze(-1)
+ position_ids = torch.tensor(
+ [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
+ ).unsqueeze(-1)
if past is None:
past = past_key_values
@@ -1149,44 +1159,38 @@ def prepare_inputs_for_generation(
"input_ids": last_token,
"past_key_values": past,
"position_ids": position_ids,
- "attention_mask": attention_mask
+ "attention_mask": attention_mask,
}
else:
if attention_mask is not None and attention_mask.dtype != torch.bool:
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
attention_mask = None
if attention_mask is None:
- attention_mask = self.get_masks(
- input_ids,
- device=input_ids.device
- )
+ attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
position_ids = self.get_position_ids(
- input_ids,
- device=input_ids.device,
- mask_positions=mask_positions,
- use_gmasks=use_gmasks
+ input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
)
return {
"input_ids": input_ids,
"past_key_values": past,
"position_ids": position_ids,
- "attention_mask": attention_mask
+ "attention_mask": attention_mask,
}
def forward(
- self,
- input_ids: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1235,7 +1239,7 @@ def forward(
@staticmethod
def _reorder_cache(
- past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@@ -1268,15 +1272,33 @@ def process_response(self, response):
return response
@torch.no_grad()
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ max_length: int = 2048,
+ num_beams=1,
+ do_sample=True,
+ top_p=0.7,
+ temperature=0.95,
+ logits_processor=None,
+ **kwargs,
+ ):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ gen_kwargs = {
+ "max_length": max_length,
+ "num_beams": num_beams,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
if not history:
prompt = query
else:
@@ -1287,22 +1309,38 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
outputs = self.generate(**inputs, **gen_kwargs)
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
history = history + [(query, response)]
return response, history
@torch.no_grad()
- def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ max_length: int = 2048,
+ do_sample=True,
+ top_p=0.7,
+ temperature=0.95,
+ logits_processor=None,
+ **kwargs,
+ ):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
- gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ gen_kwargs = {
+ "max_length": max_length,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
if not history:
prompt = query
else:
@@ -1313,7 +1351,7 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
for outputs in self.stream_generate(**inputs, **gen_kwargs):
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
new_history = history + [(query, response)]
@@ -1321,13 +1359,13 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No
@torch.no_grad()
def stream_generate(
- self,
- input_ids,
- generation_config: Optional[GenerationConfig] = None,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
- **kwargs,
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ **kwargs,
):
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
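
The `chat` and `stream_chat` entry points above wrap `generate` with prompt building and response post-processing. A minimal usage sketch, assuming a ChatGLM checkpoint loadable with `trust_remote_code` (the checkpoint name is illustrative):

    from transformers import AutoModel, AutoTokenizer

    # Illustrative checkpoint; any ChatGLM checkpoint shipping this custom
    # modeling code should behave the same way.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

    # Blocking call: returns the full response plus the updated history.
    response, history = model.chat(tokenizer, "Hello", history=[], top_p=0.7, temperature=0.95)

    # Streaming call: yields (partial_response, updated_history) as tokens arrive.
    for partial_response, new_history in model.stream_chat(tokenizer, "Hello again", history=history):
        print(partial_response, end="\r")
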
diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py
index de0d63f95f50..e3afac88c7a7 100644
--- a/applications/Chat/coati/models/generation.py
+++ b/applications/Chat/coati/models/generation.py
@@ -16,9 +16,9 @@
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-def _prepare_logits_processor(top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None) -> LogitsProcessorList:
+def _prepare_logits_processor(
+ top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
-def _sample(model: Actor,
- input_ids: torch.Tensor,
- max_length: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
+def _sample(
+ model: Actor,
+ input_ids: torch.Tensor,
+ max_length: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
@@ -56,11 +58,12 @@ def _sample(model: Actor,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
- model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
- if prepare_inputs_fn is not None else {'input_ids': input_ids}
+ model_inputs = (
+ prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+ )
outputs = model(**model_inputs)
- next_token_logits = outputs['logits'][:, -1, :]
+ next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
@@ -90,20 +93,22 @@ def _sample(model: Actor,
@torch.no_grad()
-def generate(model: Actor,
- input_ids: torch.Tensor,
- max_length: int,
- num_beams: int = 1,
- do_sample: bool = True,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
+def generate(
+ model: Actor,
+ input_ids: torch.Tensor,
+ max_length: int,
+ num_beams: int = 1,
+ do_sample: bool = True,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
@@ -121,26 +126,28 @@ def generate(model: Actor,
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
- is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
- is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
- is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
+ is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+ is_sample_gen_mode = (num_beams == 1) and do_sample is True
+ is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode:
# run greedy search
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
- return _sample(model,
- input_ids,
- max_length,
- early_stopping=early_stopping,
- eos_token_id=eos_token_id,
- pad_token_id=pad_token_id,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- prepare_inputs_fn=prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn,
- **model_kwargs)
+ return _sample(
+ model,
+ input_ids,
+ max_length,
+ early_stopping=early_stopping,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ prepare_inputs_fn=prepare_inputs_fn,
+ update_model_kwargs_fn=update_model_kwargs_fn,
+ **model_kwargs,
+ )
elif is_beam_gen_mode:
raise NotImplementedError
else:
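
Only the sampling branch is implemented here; greedy and beam search raise `NotImplementedError`. A minimal sketch of calling `generate` on a coati `Actor` (hyperparameters illustrative):

    import torch

    from coati.models.generation import generate
    from coati.models.gpt import GPTActor

    actor = GPTActor(pretrained="gpt2")
    input_ids = torch.tensor([[50256]])  # trivial one-token prompt

    # num_beams=1 with do_sample=True selects the only implemented path (_sample).
    sequences = generate(
        actor,
        input_ids,
        max_length=32,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.8,
        eos_token_id=50256,
        pad_token_id=50256,
    )
    print(sequences.shape)  # prompt tokens followed by sampled tokens
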
diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py
index 63dc5ab0f5ea..823cf4a75e0d 100644
--- a/applications/Chat/coati/models/gpt/__init__.py
+++ b/applications/Chat/coati/models/gpt/__init__.py
@@ -2,4 +2,4 @@
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
-__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
+__all__ = ["GPTActor", "GPTCritic", "GPTRM"]
diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py
index ae9d669f1f56..a7e4b9bc3e22 100644
--- a/applications/Chat/coati/models/gpt/gpt_actor.py
+++ b/applications/Chat/coati/models/gpt/gpt_actor.py
@@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRA layer.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py
index 01e1cd10ef57..22ab36dea276 100644
--- a/applications/Chat/coati/models/gpt/gpt_critic.py
+++ b/applications/Chat/coati/models/gpt/gpt_critic.py
@@ -18,12 +18,14 @@ class GPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py
index e52a5a14c1da..8edfc4008466 100644
--- a/applications/Chat/coati/models/gpt/gpt_rm.py
+++ b/applications/Chat/coati/models/gpt/gpt_rm.py
@@ -18,11 +18,13 @@ class GPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
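
GPTActor, GPTCritic, and GPTRM (and their LLaMA/OPT counterparts below) share one constructor pattern: load `pretrained` if given, otherwise build from `config`, otherwise fall back to the default config. A brief sketch:

    from coati.models.gpt import GPTActor, GPTCritic, GPTRM

    # Passing lora_rank > 0 enables LoRA; rank 0 keeps the model dense.
    actor = GPTActor(pretrained="gpt2", lora_rank=8, lora_train_bias="none")
    critic = GPTCritic(pretrained="gpt2", lora_rank=8)
    reward_model = GPTRM(pretrained="gpt2")
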
diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py
index 9b2a024afdb2..c87d732538a9 100644
--- a/applications/Chat/coati/models/llama/__init__.py
+++ b/applications/Chat/coati/models/llama/__init__.py
@@ -2,4 +2,4 @@
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
-__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py
index 2c7adb390d8b..f1d9406835ca 100644
--- a/applications/Chat/coati/models/llama/llama_actor.py
+++ b/applications/Chat/coati/models/llama/llama_actor.py
@@ -1,7 +1,6 @@
from typing import Optional
-import torch
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaForCausalLM
from ..base import Actor
@@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py
index a67e5de5def6..000dce17ccf0 100644
--- a/applications/Chat/coati/models/llama/llama_critic.py
+++ b/applications/Chat/coati/models/llama/llama_critic.py
@@ -17,13 +17,14 @@ class LlamaCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py
index d6b62922686e..43bc9e638dc7 100644
--- a/applications/Chat/coati/models/llama/llama_rm.py
+++ b/applications/Chat/coati/models/llama/llama_rm.py
@@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+from transformers import LlamaConfig, LlamaModel
from ..base import RewardModel
@@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index f1597da540a7..2114913e107b 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -8,8 +8,7 @@
class LoraLinear(lora.LoRALayer, nn.Module):
- """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
- """
+ """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
@@ -17,16 +16,14 @@ def __init__(
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
- lora_dropout: float = 0.,
- fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ lora_dropout: float = 0.0,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
- lora.LoRALayer.__init__(self,
- r=r,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- merge_weights=merge_weights)
+ lora.LoRALayer.__init__(
+ self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
+ )
self.weight = weight
self.bias = bias
@@ -47,13 +44,12 @@ def __init__(
self.weight.data = self.weight.data.T
def reset_parameters(self):
- if hasattr(self, 'lora_A'):
+ if hasattr(self, "lora_A"):
# Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
-
def T(w):
return w.T if self.fan_in_fan_out else w
@@ -71,7 +67,6 @@ def T(w):
self.merged = False
def eval(self):
-
def T(w):
return w.T if self.fan_in_fan_out else w
@@ -80,12 +75,11 @@ def T(w):
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
- delattr(self, 'lora_A')
- delattr(self, 'lora_B')
+ delattr(self, "lora_A")
+ delattr(self, "lora_B")
self.merged = True
def forward(self, x: torch.Tensor):
-
def T(w):
return w.T if self.fan_in_fan_out else w
@@ -99,7 +93,9 @@ def T(w):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
- assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
+ assert (
+ lora_rank <= linear.in_features
+ ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
@@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
_convert_to_lora_recursively(child, lora_rank)
-def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
+def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
@@ -140,7 +136,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'.
"""
- def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
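
The merge performed in `eval()` above folds the low-rank update into the base weight: W' = W + (B @ A) * scaling, with scaling = lora_alpha / r. A self-contained sketch (toy shapes, random data) checking that the merged and unmerged forward passes agree:

    import torch

    out_features, in_features, r, lora_alpha = 8, 16, 4, 1
    W = torch.randn(out_features, in_features)
    A = torch.randn(r, in_features)   # lora_A: kaiming-uniform in LoraLinear
    B = torch.randn(out_features, r)  # lora_B is zero-initialized in LoraLinear;
                                      # random here so the check is non-trivial
    scaling = lora_alpha / r

    x = torch.randn(2, in_features)
    y_unmerged = x @ W.T + (x @ A.T @ B.T) * scaling  # training-time path
    y_merged = x @ (W + (B @ A) * scaling).T          # eval-time merged path
    assert torch.allclose(y_unmerged, y_merged, atol=1e-5)
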
diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py
index 05a0b4821797..4ad4f4dcd275 100644
--- a/applications/Chat/coati/models/loss.py
+++ b/applications/Chat/coati/models/loss.py
@@ -31,11 +31,13 @@ def __init__(self, clip_eps: float = 0.2) -> None:
super().__init__()
self.clip_eps = clip_eps
- def forward(self,
- log_probs: torch.Tensor,
- old_log_probs: torch.Tensor,
- advantages: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(
+ self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
@@ -55,14 +57,16 @@ def __init__(self, clip_eps: float = 0.4) -> None:
super().__init__()
self.clip_eps = clip_eps
- def forward(self,
- values: torch.Tensor,
- old_values: torch.Tensor,
- reward: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(
+ self,
+ values: torch.Tensor,
+ old_values: torch.Tensor,
+ reward: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
- surr1 = (values_clipped - reward)**2
- surr2 = (values - reward)**2
+ surr1 = (values_clipped - reward) ** 2
+ surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return 0.5 * loss
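
For reference, PolicyLoss takes the negative minimum of the two surrogates and averages it (masked by action_mask when given); the hunk above only shows the surrogate terms, and the final -min/mean step below follows the standard PPO objective. A tiny numeric check of the clipping behavior with clip_eps = 0.2:

    import torch

    log_probs = torch.tensor([0.0])
    old_log_probs = torch.tensor([-0.5])
    advantages = torch.tensor([1.0])

    ratio = (log_probs - old_log_probs).exp()   # exp(0.5) ~= 1.649
    surr1 = ratio * advantages                  # ~= 1.649
    surr2 = ratio.clamp(0.8, 1.2) * advantages  # clipped to 1.2
    loss = -torch.min(surr1, surr2).mean()      # = -1.2: the clipped term wins
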
diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py
index 334f4df0032a..e37d6e45c8fc 100644
--- a/applications/Chat/coati/models/opt/__init__.py
+++ b/applications/Chat/coati/models/opt/__init__.py
@@ -2,4 +2,4 @@
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
-__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
+__all__ = ["OPTActor", "OPTCritic", "OPTRM"]
diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py
index c14e4377ffb2..cd8908e13fb8 100644
--- a/applications/Chat/coati/models/opt/opt_actor.py
+++ b/applications/Chat/coati/models/opt/opt_actor.py
@@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py
index f66c4173fa52..f37d28812c27 100644
--- a/applications/Chat/coati/models/opt/opt_critic.py
+++ b/applications/Chat/coati/models/opt/opt_critic.py
@@ -18,12 +18,14 @@ class OPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py
index 6f75344e6aae..893708344ad4 100644
--- a/applications/Chat/coati/models/opt/opt_rm.py
+++ b/applications/Chat/coati/models/opt/opt_rm.py
@@ -17,11 +17,13 @@ class OPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py
index 97637d3523b0..def6190dd71c 100644
--- a/applications/Chat/coati/models/utils.py
+++ b/applications/Chat/coati/models/utils.py
@@ -4,9 +4,9 @@
import torch.nn.functional as F
-def _compute_approx_kl(log_probs: torch.Tensor,
- log_probs_base: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def _compute_approx_kl(
+ log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
+) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl
-def compute_reward(r: Union[torch.Tensor, float],
- kl_coef: float,
- log_probs: torch.Tensor,
- log_probs_base: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def compute_reward(
+ r: Union[torch.Tensor, float],
+ kl_coef: float,
+ log_probs: torch.Tensor,
+ log_probs_base: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
@@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
Returns:
torch.Tensor: Action log probs.
"""
- logits = output['logits']
+ logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
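
compute_reward shapes the raw reward with an approximate KL penalty: reward = r - kl_coef * KL(actor || initial model). A hedged sketch using the simple k1 estimator from the linked Schulman post (the exact estimator lives in the unshown body of `_compute_approx_kl`):

    import torch

    log_probs = torch.tensor([[-1.0, -2.0]])       # actor log-probs per action token
    log_probs_base = torch.tensor([[-1.2, -1.8]])  # initial-model log-probs

    log_ratio = log_probs - log_probs_base  # k1 estimator of the per-token KL
    kl = log_ratio.mean(dim=1)              # aggregate over action tokens

    r, kl_coef = torch.tensor([1.0]), 0.1
    reward = r - kl_coef * kl               # shaped reward passed to PPO
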
diff --git a/applications/Chat/coati/quant/__init__.py b/applications/Chat/coati/quant/__init__.py
index a65a78d07bb8..1765b8091bc3 100644
--- a/applications/Chat/coati/quant/__init__.py
+++ b/applications/Chat/coati/quant/__init__.py
@@ -2,6 +2,6 @@
from .utils import low_resource_init
__all__ = [
- 'llama_load_quant',
- 'low_resource_init',
+ "llama_load_quant",
+ "low_resource_init",
]
diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/Chat/coati/quant/llama_gptq/__init__.py
index 51c8d6316290..51d5233586ad 100644
--- a/applications/Chat/coati/quant/llama_gptq/__init__.py
+++ b/applications/Chat/coati/quant/llama_gptq/__init__.py
@@ -1,5 +1,5 @@
from .loader import load_quant
__all__ = [
- 'load_quant',
+ "load_quant",
]
diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/Chat/coati/quant/llama_gptq/loader.py
index 5353dc8a2ea3..50486337a7ab 100644
--- a/applications/Chat/coati/quant/llama_gptq/loader.py
+++ b/applications/Chat/coati/quant/llama_gptq/loader.py
@@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
# ignore lm head
layers = find_layers(model)
- for name in ['lm_head']:
+ for name in ["lm_head"]:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize)
- if checkpoint.endswith('.safetensors'):
+ if checkpoint.endswith(".safetensors"):
from safetensors.torch import load_file as safe_load
+
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/Chat/coati/quant/llama_gptq/model_utils.py
index 62db171abb52..18e4e4761500 100644
--- a/applications/Chat/coati/quant/llama_gptq/model_utils.py
+++ b/applications/Chat/coati/quant/llama_gptq/model_utils.py
@@ -1,13 +1,12 @@
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
-import torch
import torch.nn as nn
-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
- res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
+ res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
return res
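
A quick check of `find_layers` on a toy module tree, assuming the module is importable under the path shown in the diff header:

    import torch.nn as nn

    from coati.quant.llama_gptq.model_utils import find_layers

    model = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4)))
    # Returns dotted names for every Linear/Conv2d leaf, which make_quant
    # then swaps for QuantLinear instances.
    print(find_layers(model))  # {'0': Linear(...), '1.0': Linear(...)}
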
diff --git a/applications/Chat/coati/quant/llama_gptq/quant.py b/applications/Chat/coati/quant/llama_gptq/quant.py
index f7d5b7ce4bd8..5a7e2e72dfc5 100644
--- a/applications/Chat/coati/quant/llama_gptq/quant.py
+++ b/applications/Chat/coati/quant/llama_gptq/quant.py
@@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
class Quantizer(nn.Module):
-
def __init__(self, shape=1):
super(Quantizer, self).__init__()
- self.register_buffer('maxq', torch.tensor(0))
- self.register_buffer('scale', torch.zeros(shape))
- self.register_buffer('zero', torch.zeros(shape))
+ self.register_buffer("maxq", torch.tensor(0))
+ self.register_buffer("scale", torch.zeros(shape))
+ self.register_buffer("zero", torch.zeros(shape))
- def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
+ def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
@@ -68,7 +67,7 @@ def find_params(self, x, weight=False):
self.zero = torch.round(-xmin / self.scale)
if self.mse:
- best = torch.full([x.shape[0]], float('inf'), device=dev)
+ best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
@@ -123,13 +122,12 @@ def ready(self):
try:
import quant_cuda
except:
- print('CUDA extension not installed.')
+ print("CUDA extension not installed.")
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
-
def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__()
if bits not in [2, 3, 4, 8]:
@@ -142,11 +140,11 @@ def __init__(self, bits, groupsize, infeatures, outfeatures):
groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize
self.register_buffer(
- 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
- dtype=torch.int))
- self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
- self.register_buffer('bias', torch.zeros(outfeatures))
- self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
+ "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
+ )
+ self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
+ self.register_buffer("bias", torch.zeros(outfeatures))
+ self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False
def pack(self, linear, scales, zeros):
@@ -161,8 +159,10 @@ def pack(self, linear, scales, zeros):
for idx in range(self.infeatures):
g_idx = idx // self.groupsize
intweight.append(
- torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
- None])
+ torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
+ :, None
+ ]
+ )
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
@@ -271,13 +271,13 @@ def forward(self, x):
return y.reshape(outshape)
-def make_quant(module, names, bits, groupsize, name=''):
+def make_quant(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
- name1 = name + '.' + attr if name != '' else attr
+ name1 = name + "." + attr if name != "" else attr
if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children():
- make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+ make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
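
QuantLinear stores its `bits`-bit integer weights packed into int32 buffers, hence the `// 256 * (bits * 8)` sizing above, which works out to `infeatures * bits / 32` words. A self-contained sketch of the packing idea for bits = 4:

    # Eight 4-bit values fit in one 32-bit word (32 / 4 = 8).
    bits = 4
    values = [3, 15, 0, 7, 9, 2, 11, 5]

    packed = 0
    for i, v in enumerate(values):
        packed |= (v & 0xF) << (i * bits)

    # Unpacking reverses the shifts and masks off 4 bits at a time.
    mask = (1 << bits) - 1
    unpacked = [(packed >> (i * bits)) & mask for i in range(len(values))]
    assert unpacked == values
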
diff --git a/applications/Chat/coati/quant/utils.py b/applications/Chat/coati/quant/utils.py
index 01b8cff0add1..d102bb30f52d 100644
--- a/applications/Chat/coati/quant/utils.py
+++ b/applications/Chat/coati/quant/utils.py
@@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
@contextmanager
def low_resource_init():
- """This context manager disables weight initialization and sets the default float dtype to half.
- """
+ """This context manager disables weight initialization and sets the default float dtype to half."""
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
old_uniform_ = torch.nn.init.uniform_
old_normal_ = torch.nn.init.normal_
diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/Chat/coati/ray/callbacks/base.py
index 3306150a41ff..8c5bd8a67776 100644
--- a/applications/Chat/coati/ray/callbacks/base.py
+++ b/applications/Chat/coati/ray/callbacks/base.py
@@ -5,7 +5,7 @@
class TrainerCallback(ABC):
"""
- Base callback class. It defines the interface for callbacks.
+ Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
@@ -40,7 +40,6 @@ def on_update_end(self) -> None:
class MakerCallback(ABC):
-
def on_loop_start(self) -> None:
pass
diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
index d3df8f9ae3e0..18798bce7dce 100644
--- a/applications/Chat/coati/ray/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
@@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
-
def __init__(self) -> None:
self.start_time: Optional[float] = None
- self.duration: float = 0.
+ self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@@ -42,13 +41,13 @@ def end(self) -> None:
self.duration += time() - self.start_time
def reset(self) -> None:
- self.duration = 0.
+ self.duration = 0.0
class ExperienceMakerPerformanceEvaluator(MakerCallback):
-
- def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
- reward_model_num_params: int) -> None:
+ def __init__(
+ self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
+ ) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -63,7 +62,7 @@ def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_
self.make_experience_flop: int = 0
print_rank_0(
- f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
+ f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_make_experience_start(self) -> None:
@@ -110,27 +109,29 @@ def on_loop_end(self) -> None:
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
- avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
- (self.total_samples * self.world_size)
+ avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
+ self.total_samples * self.world_size
+ )
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
- 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
- + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
- + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
- + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-
- + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ "Making Experience Performance Summary:\n"
+ + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ + f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
class TrainerPerformanceEvaluator(TrainerCallback):
-
- def __init__(self,
- actor_num_params: int,
- critic_num_params: int,
- enable_grad_checkpoint: bool = False,
- ignore_first_episodes: int = 1) -> None:
+ def __init__(
+ self,
+ actor_num_params: int,
+ critic_num_params: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_first_episodes: int = 1,
+ ) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -146,7 +147,7 @@ def __init__(self,
self.learn_flop: int = 0
print_rank_0(
- f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
+ f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_episode_start(self, episodes: int) -> None:
@@ -191,7 +192,7 @@ def on_update_end(self) -> None:
def on_fit_end(self) -> None:
if self.total_samples == 0:
- print_rank_0('No samples are collected, skip trainer performance evaluation')
+ print_rank_0("No samples are collected, skip trainer performance evaluation")
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
@@ -204,9 +205,10 @@ def on_fit_end(self) -> None:
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
- 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
- + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
- + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-
- + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ "Learning Performance Summary:\n"
+ + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ + f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py
index e04bf5ccb881..92dab17292f7 100644
--- a/applications/Chat/coati/ray/detached_replay_buffer.py
+++ b/applications/Chat/coati/ray/detached_replay_buffer.py
@@ -1,20 +1,15 @@
-import asyncio
-import copy
-import random
-from threading import Lock
-from typing import Any, List
+from typing import List
-import ray
import torch
-from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience
+
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
class DetachedReplayBuffer:
- '''
+ """
Detached replay buffer. Share Experience across workers on the same node.
Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call the append(exp) method remotely.
@@ -24,7 +19,7 @@ class DetachedReplayBuffer:
tp_world_size: Number of workers in the same tp group
limit: Limit on the number of experience sample batches. A number <= 0 means unlimited. Defaults to 0.
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
- '''
+ """
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size
@@ -34,23 +29,23 @@ def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
@torch.no_grad()
def append(self, experience: Experience) -> None:
- '''
+ """
Expected to be called remotely.
- '''
+ """
items = split_experience_batch(experience)
self.extend(items)
@torch.no_grad()
def extend(self, items: List[BufferItem]) -> None:
- '''
+ """
Expected to be called remotely.
- '''
+ """
self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size:
- items = self.batch_collector[:self.sample_batch_size]
+ items = self.batch_collector[: self.sample_batch_size]
experience = make_experience_batch(items)
self.items.put(experience, block=True)
- self.batch_collector = self.batch_collector[self.sample_batch_size:]
+ self.batch_collector = self.batch_collector[self.sample_batch_size :]
def clear(self) -> None:
# self.items.close()
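
The extend method above accumulates items and flushes fixed-size batches onto the ray Queue. A standalone sketch of that batching logic, with plain lists standing in for BufferItem and the Queue:

    class MiniBuffer:
        def __init__(self, sample_batch_size: int) -> None:
            self.sample_batch_size = sample_batch_size
            self.batch_collector = []

        def extend(self, items):
            self.batch_collector.extend(items)
            batches = []
            while len(self.batch_collector) >= self.sample_batch_size:
                # In the real buffer: make_experience_batch(...) then Queue.put(...).
                batches.append(self.batch_collector[: self.sample_batch_size])
                self.batch_collector = self.batch_collector[self.sample_batch_size :]
            return batches

    buf = MiniBuffer(4)
    assert buf.extend([0, 1, 2, 3, 4, 5]) == [[0, 1, 2, 3]]  # 4 and 5 stay buffered
    assert buf.extend([6, 7]) == [[4, 5, 6, 7]]
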
diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py
index 90399781187a..fcf0a472df9e 100644
--- a/applications/Chat/coati/ray/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/detached_trainer_base.py
@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, List
import ray
import torch
@@ -15,7 +15,7 @@
class DetachedTrainer(ABC):
- '''
+ """
Base class for detached rlhf trainers.
'detach' means that the experience maker is detached compared to a normal Trainer.
Please set name attribute during init:
@@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
- '''
-
- def __init__(self,
- experience_maker_holder_name_list: List[str],
- train_batch_size: int = 8,
- buffer_limit: int = 0,
- dataloader_pin_memory: bool = True,
- callbacks: List[TrainerCallback] = [],
- debug: bool = False) -> None:
+ """
+
+ def __init__(
+ self,
+ experience_maker_holder_name_list: List[str],
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ dataloader_pin_memory: bool = True,
+ callbacks: List[TrainerCallback] = [],
+ debug: bool = False,
+ ) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory
@@ -67,18 +69,16 @@ def training_step(self, experience: Experience) -> Dict[str, Any]:
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
- pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
+ pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
- dataloader = DataLoader(data,
- batch_size=1,
- shuffle=True,
- pin_memory=self.dataloader_pin_memory,
- collate_fn=lambda x: x[0])
+ dataloader = DataLoader(
+ data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
+ )
for epoch in range(1, train_epochs):
- pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
+ pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)
@@ -104,7 +104,7 @@ def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
- for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
+ for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()
diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py
index 2f2aa0e29579..ef84a1ddba48 100644
--- a/applications/Chat/coati/ray/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -1,12 +1,11 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple
import ray
import torch
-from coati.experience_maker import Experience, NaiveExperienceMaker
+from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
-from coati.trainer.callbacks import Callback
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
+from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam
@@ -14,27 +13,14 @@
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
-from .utils import (
- get_actor_from_args,
- get_critic_from_args,
- get_model_numel,
- get_rank,
- get_strategy_from_args,
- is_rank_0,
- set_dist_env,
- state_dict_to,
-)
+from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
-@ray.remote(concurrency_groups={
- "buffer_length": 1,
- "buffer_append": 1,
- "buffer_sample": 1,
- "model_io": 1,
- "compute": 1
-})
+@ray.remote(
+ concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
+)
class DetachedPPOTrainer(DetachedTrainer):
- '''
+ """
Detached Trainer for PPO algorithm
Args:
strategy (Strategy): the strategy to use for training
@@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
- '''
+ """
def __init__(
self,
@@ -92,21 +78,24 @@ def __init__(
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
- (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
- self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
+ (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
+ (self.actor, self.actor_optim), (self.critic, self.critic_optim)
+ )
# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
- super().__init__(experience_maker_holder_name_list,
- train_batch_size=train_batch_size,
- buffer_limit=buffer_limit,
- dataloader_pin_memory=dataloader_pin_memory,
- callbacks=callbacks,
- debug=debug)
+ super().__init__(
+ experience_maker_holder_name_list,
+ train_batch_size=train_batch_size,
+ buffer_limit=buffer_limit,
+ dataloader_pin_memory=dataloader_pin_memory,
+ callbacks=callbacks,
+ debug=debug,
+ )
if self._debug:
- print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
+ print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
self._update_lora_weights = update_lora_weights
@@ -115,7 +104,7 @@ def __init__(
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if not fully_update:
- config['requires_grad_only'] = True
+ config["requires_grad_only"] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []
@@ -131,7 +120,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
- fully_update=fully_update))
+ fully_update=fully_update,
+ )
+ )
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
@@ -139,7 +130,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
- fully_update=fully_update))
+ fully_update=fully_update,
+ )
+ )
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
@@ -152,26 +145,24 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
- actor_loss = self.actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ actor_loss = self.actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
- values = self.critic(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self.critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self.critic(
+ experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+ )
+ critic_loss = self.critic_loss_fn(
+ values, experience.values, experience.reward, action_mask=experience.action_mask
+ )
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
- return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+ return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0)
diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py
index 13314bdafd5f..4d290f4aba88 100644
--- a/applications/Chat/coati/ray/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/experience_maker_holder.py
@@ -1,53 +1,49 @@
import os
import time
import tracemalloc
-from copy import deepcopy
from threading import Lock
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import ray
import torch
-import torch.nn as nn
-from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
-from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
+from coati.experience_buffer.utils import split_experience_batch
+from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
-from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
-from coati.trainer.strategies.sampler import DistributedSampler
-from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
-from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
+from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
- '''
+ """
Args:
detached_trainer_name_list: list of trainer names used to get Ray actor handles
strategy:
kl_coef: the coefficient of the KL divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
- '''
+ """
def __init__(
- self,
- detached_trainer_name_list: List[str],
- strategy_fn: Callable[[], Strategy],
+ self,
+ detached_trainer_name_list: List[str],
+ strategy_fn: Callable[[], Strategy],
# a function that returns (actor, critic, reward_model, initial_model)
- model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
- env_info: Dict[str, str] = None,
- sync_models_from_trainers: bool = False,
- buffer_cpu_offload: bool = True,
- kl_coef: float = 0.1,
- callbacks: List[MakerCallback] = [],
- eval_performance: bool = False,
- debug: bool = False,
- update_lora_weights: bool = False,
- **generate_kwargs):
+ model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
+ env_info: Dict[str, str] = None,
+ sync_models_from_trainers: bool = False,
+ buffer_cpu_offload: bool = True,
+ kl_coef: float = 0.1,
+ callbacks: List[MakerCallback] = [],
+ eval_performance: bool = False,
+ debug: bool = False,
+ update_lora_weights: bool = False,
+ **generate_kwargs,
+ ):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
@@ -66,8 +62,9 @@ def __init__(
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
- evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
- reward_model_numel)
+ evaluator = ExperienceMakerPerformanceEvaluator(
+ actor_numel, critic_numel, initial_model_numel, reward_model_numel
+ )
callbacks = callbacks + [evaluator]
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
@@ -89,9 +86,9 @@ def __init__(
self._target_idx = 0
if self._debug:
- print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
+ print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized:
- print(f'[maker{get_rank()}] Waiting for INIT')
+ print(f"[maker{get_rank()}] Waiting for INIT")
def _get_ready(self):
while not self._fully_initialized():
@@ -136,7 +133,7 @@ def _inference_step(self, batch) -> None:
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
- experience.to_device('cpu')
+ experience.to_device("cpu")
self._send_items(experience)
self._on_send_end()
self._on_batch_end()
@@ -155,7 +152,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
if num_steps > 0:
# num_epochs is ignored when num_steps > 0
it = iter(dataloader)
- for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
+ for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:
@@ -163,7 +160,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
batch = next(it)
self._inference_step(batch)
else:
- with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
+ with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)
@@ -171,22 +168,24 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
self._on_loop_end()
@ray.method(concurrency_group="model_io")
- def update_experience_maker(self,
- new_actor_state_dict: Dict[str, Any] = None,
- new_actor_lora_config_dict: Dict[str, Any] = None,
- new_critic_state_dict: Dict[str, Any] = None,
- new_critic_lora_config_dict: Dict[str, Any] = None,
- fully_update: bool = False,
- chunk_start: bool = None,
- chunk_end: bool = None):
- '''
- called by trainer
- chunk_start: Set True at the first call. Before sending state_dict calls
- chunk_end: Set True at the last call. After sending state_dict calls.
- fully_update: Set True if you want to sync models when initializing
-
- TODO: load_state_dict integrate with model-sharding strategy
- '''
+ def update_experience_maker(
+ self,
+ new_actor_state_dict: Dict[str, Any] = None,
+ new_actor_lora_config_dict: Dict[str, Any] = None,
+ new_critic_state_dict: Dict[str, Any] = None,
+ new_critic_lora_config_dict: Dict[str, Any] = None,
+ fully_update: bool = False,
+ chunk_start: bool = None,
+ chunk_end: bool = None,
+ ):
+ """
+ called by trainer
+ chunk_start: Set True at the first call. Before sending state_dict calls
+ chunk_end: Set True at the last call. After sending state_dict calls.
+ fully_update: Set True if you want to sync models when initializing
+
+ TODO: load_state_dict integrate with model-sharding strategy
+ """
_watch_memory = self._debug
if chunk_start:
if self._debug:
@@ -202,18 +201,22 @@ def update_experience_maker(self,
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
- new_actor_state_dict, new_actor_lora_config_dict)
+ new_actor_state_dict, new_actor_lora_config_dict
+ )
self.actor_lora_constructor.load_state_dict_increase(
- self.experience_maker.actor.model, state_dict_increase)
+ self.experience_maker.actor.model, state_dict_increase
+ )
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
- new_critic_state_dict, new_critic_lora_config_dict)
+ new_critic_state_dict, new_critic_lora_config_dict
+ )
self.critic_lora_constructor.load_state_dict_increase(
- self.experience_maker.critic, state_dict_increase)
+ self.experience_maker.critic, state_dict_increase
+ )
# the lock must be released only after both the actor and the critic have been updated
if chunk_end:
@@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
- if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
- new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+ if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
- if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
- new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
+ if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
return new_kwargs
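For reference, the chunked sync protocol documented in update_experience_maker() can be sketched outside Ray. FakeHolder below is a hypothetical stand-in for the remote actor, not coati code; in the real system each call would be holder.update_experience_maker.remote(...).

from typing import Any, Dict, List, Optional


class FakeHolder:
    # Hypothetical stand-in for ExperienceMakerHolder, for illustration only.
    def __init__(self) -> None:
        self.received: Dict[str, Any] = {}
        self.locked = False

    def update_experience_maker(
        self,
        new_actor_state_dict: Optional[Dict[str, Any]] = None,
        fully_update: bool = False,
        chunk_start: bool = False,
        chunk_end: bool = False,
    ) -> None:
        if chunk_start:
            self.locked = True  # the real holder acquires a threading.Lock here
        if new_actor_state_dict is not None:
            self.received.update(new_actor_state_dict)
        if chunk_end:
            self.locked = False  # released only after actor and critic are updated


def send_in_chunks(holder: FakeHolder, shards: List[Dict[str, Any]]) -> None:
    # One chunk_start call, N shard-sending calls, one chunk_end call.
    holder.update_experience_maker(chunk_start=True)
    for shard in shards:
        holder.update_experience_maker(new_actor_state_dict=shard)
    holder.update_experience_maker(chunk_end=True)


holder = FakeHolder()
send_in_chunks(holder, [{"layer1.weight": 1}, {"layer2.weight": 2}])
assert holder.received == {"layer1.weight": 1, "layer2.weight": 2}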
diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py
index a98545d4d751..8e9f78700e29 100644
--- a/applications/Chat/coati/ray/lora_constructor.py
+++ b/applications/Chat/coati/ray/lora_constructor.py
@@ -1,11 +1,9 @@
from collections import OrderedDict
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict
-import torch
import torch.nn as nn
from coati.models.lora import LoraLinear
-from loralib.layers import LoRALayer
@dataclass
@@ -17,7 +15,7 @@ class LoRAConfig:
class LoRAConstructor:
- '''
+ """
Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!)
Usage:
@@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver):
load_state_dict_increase()
- '''
+ """
def __init__(self):
self.lora_config_dict = None
@@ -45,10 +43,10 @@ def register_lora_config(self, lora_config_dict: Dict[str, Any]):
self.lora_config_dict = lora_config_dict
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
- '''
- xxx.lora_A, xxx.lora_B -->> xxx.weight
- Warning: the xxx.weight here is the increment actually.
- '''
+ """
+ xxx.lora_A, xxx.lora_B -->> xxx.weight
+ Warning: the xxx.weight here is actually the increment, not the full weight.
+ """
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)
@@ -56,24 +54,25 @@ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
- if k.rpartition('.')[-1] == 'lora_A':
+ if k.rpartition(".")[-1] == "lora_A":
lora_A = v
- layer_prefix = k.rpartition('.')[0]
- elif k.rpartition('.')[-1] == 'lora_B':
- assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
+ layer_prefix = k.rpartition(".")[0]
+ elif k.rpartition(".")[-1] == "lora_B":
+ assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
- state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
+ state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
- raise ValueError('unexpected key')
+ raise ValueError("unexpected key")
return state_dict_increase
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w):
return w.T if config.fan_in_fan_out else w
+
if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling
@@ -81,21 +80,21 @@ def T(w):
return 0
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
- '''
+ """
The final reconstruction step
- '''
+ """
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
- '''
+ """
If keep_non_lora is True, also return the non-LoRA state_dict.
- '''
+ """
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
- if 'lora_A' in k or 'lora_B' in k:
+ if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v
@@ -106,17 +105,19 @@ def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
- '''
+ """
Extract LoRA configs from all LoraLinear modules in the model.
Returns an OrderedDict mapping module name -> LoRAConfig.
- '''
+ """
lora_config_dict = OrderedDict()
for name, child in model.named_modules():
if isinstance(child, LoraLinear):
- lora_config_dict[name] = LoRAConfig(r=child.r,
- lora_alpha=child.lora_alpha,
- lora_dropout=child.lora_dropout,
- fan_in_fan_out=child.fan_in_fan_out)
+ lora_config_dict[name] = LoRAConfig(
+ r=child.r,
+ lora_alpha=child.lora_alpha,
+ lora_dropout=child.lora_dropout,
+ fan_in_fan_out=child.fan_in_fan_out,
+ )
return lora_config_dict
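The five-step flow in the LoRAConstructor docstring can be exercised with plain tensors. The shapes below are illustrative assumptions; only the key layout (xxx.lora_A / xxx.lora_B), the config fields, and the scaling (lora_alpha / r) come from the code above.

import torch
from coati.ray.lora_constructor import LoRAConfig, LoRAConstructor

r, alpha, d_out, d_in = 4, 16, 8, 8
lora_A = torch.randn(r, d_in)   # stored as "layer.lora_A"
lora_B = torch.randn(d_out, r)  # stored as "layer.lora_B"
state_dict_lora = {"layer.lora_A": lora_A, "layer.lora_B": lora_B}
config_dict = {"layer": LoRAConfig(r=r, lora_alpha=alpha, lora_dropout=0.0, fan_in_fan_out=False)}

constructor = LoRAConstructor()
increase = constructor.reconstruct_increase(state_dict_lora, config_dict)
# "layer.weight" holds the increment B @ A * (alpha / r), not the full weight.
expected = (lora_B @ lora_A) * (alpha / r)
assert torch.allclose(increase["layer.weight"], expected)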
diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py
index 391ffe7a91a9..036dd145dddb 100644
--- a/applications/Chat/coati/ray/utils.py
+++ b/applications/Chat/coati/ray/utils.py
@@ -1,6 +1,6 @@
import os
from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict
import torch
import torch.distributed as dist
@@ -10,7 +10,7 @@
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
def is_rank_0() -> bool:
@@ -26,13 +26,13 @@ def get_world_size() -> int:
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
- if model == 'gpt2':
+ if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
- elif model == 'bloom':
+ elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
- elif model == 'opt':
+ elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
- elif model == 'llama':
+ elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
@@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
- if model == 'gpt2':
+ if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
- elif model == 'bloom':
+ elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
- elif model == 'opt':
+ elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
- elif model == 'llama':
+ elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else:
raise ValueError(f'Unsupported critic model "{model}"')
@@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
- if model == 'gpt2':
+ if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
- elif model == 'bloom':
+ elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
- elif model == 'opt':
+ elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
- elif model == 'llama':
+ elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
def get_strategy_from_args(strategy: str):
- if strategy == 'ddp':
+ if strategy == "ddp":
strategy_ = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
- elif strategy == 'colossalai_gemini_cpu':
- strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
- elif strategy == 'colossalai_zero2_cpu':
- strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+ elif strategy == "colossalai_gemini":
+ strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+ elif strategy == "colossalai_zero2":
+ strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif strategy == "colossalai_gemini_cpu":
+ strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+ elif strategy == "colossalai_zero2_cpu":
+ strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def get_tokenizer_from_args(model: str, **kwargs):
- if model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- elif model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
- elif model == 'opt':
+ if model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ elif model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
+ elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- elif model == 'llama':
+ elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:
@@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):
def set_dist_env(env_info: Dict[str, str]):
- os.environ["RANK"] = env_info['rank']
- os.environ["LOCAL_RANK"] = env_info['local_rank']
- os.environ["WORLD_SIZE"] = env_info['world_size']
- os.environ['MASTER_PORT'] = env_info['master_port']
- os.environ['MASTER_ADDR'] = env_info['master_addr']
+ os.environ["RANK"] = env_info["rank"]
+ os.environ["LOCAL_RANK"] = env_info["local_rank"]
+ os.environ["WORLD_SIZE"] = env_info["world_size"]
+ os.environ["MASTER_PORT"] = env_info["master_port"]
+ os.environ["MASTER_ADDR"] = env_info["master_addr"]
def get_model_numel(model: nn.Module) -> int:
@@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers
-def state_dict_to(state_dict: Dict[str, Any],
- dtype: torch.dtype = torch.float16,
- device: torch.device = torch.device('cpu')):
- '''
- keep state_dict intact
- '''
+def state_dict_to(
+ state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
+):
+ """
+ Return a converted copy; the input state_dict is kept intact.
+ """
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
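A quick check of the state_dict_to() contract from the docstring above: the function returns converted copies (here fp16 on CPU) and leaves the input tensors untouched. The import path assumes the file layout shown in this patch.

import torch
from coati.ray.utils import state_dict_to

sd = {"w": torch.randn(2, 2, dtype=torch.float32)}
new_sd = state_dict_to(sd, dtype=torch.float16, device=torch.device("cpu"))

assert sd["w"].dtype == torch.float32      # the original is kept intact
assert new_sd["w"].dtype == torch.float16  # the copy is converted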
diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py
index 86142361f3ff..4be5d27f93b1 100644
--- a/applications/Chat/coati/trainer/__init__.py
+++ b/applications/Chat/coati/trainer/__init__.py
@@ -3,8 +3,4 @@
from .rm import RewardModelTrainer
from .sft import SFTTrainer
-__all__ = [
- 'SLTrainer', 'OnPolicyTrainer',
- 'RewardModelTrainer', 'SFTTrainer',
- 'PPOTrainer'
-]
+__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py
index 0629c9c00cca..ca450edee0c3 100644
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
- def __init__(self,
- strategy: Strategy,
- data_buffer: NaiveExperienceBuffer,
- sample_buffer: bool,
- dataloader_pin_memory: bool,
- callbacks: List[Callback] = []) -> None:
+ def __init__(
+ self,
+ strategy: Strategy,
+ data_buffer: NaiveExperienceBuffer,
+ sample_buffer: bool,
+ dataloader_pin_memory: bool,
+ callbacks: List[Callback] = [],
+ ) -> None:
super().__init__()
self.strategy = strategy
self.data_buffer = data_buffer
diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py
index 9ed0ee6f7640..29c8c4f00a5c 100644
--- a/applications/Chat/coati/trainer/callbacks/__init__.py
+++ b/applications/Chat/coati/trainer/callbacks/__init__.py
@@ -2,4 +2,4 @@
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
-__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
+__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py
index f5616048855b..d5181175b324 100644
--- a/applications/Chat/coati/trainer/callbacks/base.py
+++ b/applications/Chat/coati/trainer/callbacks/base.py
@@ -5,7 +5,7 @@
class Callback(ABC):
"""
- Base callback class. It defines the interface for callbacks.
+ Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
index 9b44dafa7eaa..c2eda92cc165 100644
--- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
def divide(x: float, y: float) -> float:
if y == 0:
- return float('inf')
- elif y == float('inf'):
- return float('nan')
+ return float("inf")
+ elif y == float("inf"):
+ return float("nan")
return x / y
@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
-
def __init__(self) -> None:
self.start_time: Optional[float] = None
- self.duration: float = 0.
+ self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@@ -52,7 +51,7 @@ def end(self) -> None:
self.start_time = None
def reset(self) -> None:
- self.duration = 0.
+ self.duration = 0.0
class PerformanceEvaluator(Callback):
@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
- def __init__(self,
- actor_num_params: int,
- critic_num_params: int,
- initial_model_num_params: int,
- reward_model_num_params: int,
- enable_grad_checkpoint: bool = False,
- ignore_episodes: int = 0) -> None:
+ def __init__(
+ self,
+ actor_num_params: int,
+ critic_num_params: int,
+ initial_model_num_params: int,
+ reward_model_num_params: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_episodes: int = 0,
+ ) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -155,8 +156,9 @@ def on_fit_end(self) -> None:
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
- avg_make_experience_throughput = self.make_experience_num_samples * \
- self.world_size / (avg_make_experience_duration + 1e-12)
+ avg_make_experience_throughput = (
+ self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
+ )
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
@@ -171,13 +173,11 @@ def on_fit_end(self) -> None:
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
- f'Performance summary:\n'
- + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
-
- + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
- + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
- + f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
- + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
-
- + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
+ f"Performance summary:\n"
+ + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ + f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
)
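The throughput summary above divides sample counts by possibly-zero timer durations; divide() makes that safe. A standalone restatement with a tiny worked example:

def divide(x: float, y: float) -> float:
    # Same guards as above: x / 0 -> inf, x / inf -> nan instead of raising.
    if y == 0:
        return float("inf")
    elif y == float("inf"):
        return float("nan")
    return x / y

# e.g. 128 samples in 4 s on one rank -> 32 samples/s, mirroring
# avg_learn_throughput = num_samples * world_size / duration.
assert divide(128, 4.0) == 32.0
assert divide(1.0, 0) == float("inf")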
diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
index f0d77a191a88..0d70b6c53073 100644
--- a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
+++ b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
@@ -36,34 +36,35 @@ class SaveCheckpoint(Callback):
"""
- def __init__(self,
- path: str,
- interval: int,
- strategy: Strategy,
- actor: nn.Module = None,
- critic: nn.Module = None,
- actor_optim: Optimizer = None,
- critic_optim: Optimizer = None) -> None:
+ def __init__(
+ self,
+ path: str,
+ interval: int,
+ strategy: Strategy,
+ actor: nn.Module = None,
+ critic: nn.Module = None,
+ actor_optim: Optimizer = None,
+ critic_optim: Optimizer = None,
+ ) -> None:
super().__init__()
- self.path = os.path.join(path, 'checkpoint')
+ self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
- self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+ self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
- base_path = os.path.join(self.path, f'episode_{episode}')
+ base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
-
# save model
if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it is skipped
continue
- model_path = os.path.join(base_path, f'{model}.pt')
+ model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
@@ -71,5 +72,5 @@ def on_episode_end(self, episode: int) -> None:
continue
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
- optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+ optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
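The path construction above yields a simple on-disk layout per saved episode; the snippet below just reproduces it for a concrete episode and rank.

import os

path, episode, rank = "output", 9, 0
base_path = os.path.join(path, "checkpoint", f"episode_{episode}")
for model in ("actor", "critic"):
    print(os.path.join(base_path, f"{model}.pt"))                    # model weights
    print(os.path.join(base_path, f"{model}-optim-rank-{rank}.pt"))  # optimizer states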
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index ef625a1c1b3d..6f255a935d91 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -8,7 +8,7 @@
from coati.models.utils import calc_action_log_probs
from torch import Tensor
from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DistributedSampler
from tqdm import tqdm
from colossalai.utils import get_current_device
@@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
- if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
- new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
+ if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
- if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
- new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
+ if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
return new_kwargs
@@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer):
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
- def __init__(self,
- strategy: Strategy,
- actor: Actor,
- critic: Critic,
- reward_model: nn.Module,
- initial_model: Actor,
- actor_optim: Optimizer,
- critic_optim: Optimizer,
- kl_coef: float = 0.1,
- ptx_coef: float = 0.9,
- train_batch_size: int = 8,
- buffer_limit: int = 0,
- buffer_cpu_offload: bool = True,
- eps_clip: float = 0.2,
- vf_coef: float = 1.0,
- value_clip: float = 0.4,
- sample_buffer: bool = False,
- dataloader_pin_memory: bool = True,
- offload_inference_models: bool = True,
- callbacks: List[Callback] = [],
- **generate_kwargs
- ) -> None:
+ def __init__(
+ self,
+ strategy: Strategy,
+ actor: Actor,
+ critic: Critic,
+ reward_model: nn.Module,
+ initial_model: Actor,
+ actor_optim: Optimizer,
+ critic_optim: Optimizer,
+ kl_coef: float = 0.1,
+ ptx_coef: float = 0.9,
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ buffer_cpu_offload: bool = True,
+ eps_clip: float = 0.2,
+ vf_coef: float = 1.0,
+ value_clip: float = 0.4,
+ sample_buffer: bool = False,
+ dataloader_pin_memory: bool = True,
+ offload_inference_models: bool = True,
+ callbacks: List[Callback] = [],
+ **generate_kwargs,
+ ) -> None:
if isinstance(strategy, GeminiStrategy):
- assert not offload_inference_models, \
- "GeminiPlugin is not compatible with manual model.to('cpu')"
+ assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
- super().__init__(
- strategy, data_buffer,
- sample_buffer, dataloader_pin_memory,
- callbacks
- )
+ super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
@@ -130,18 +126,16 @@ def _training_step(self, experience: Experience) -> Dict[str, float]:
num_actions = experience.action_mask.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
- actor_loss = self.actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ actor_loss = self.actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
- ptx_log_probs = self.actor(batch['input_ids'],
- attention_mask=batch['attention_mask'])['logits']
- ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
+ ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
+ ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
@@ -149,24 +143,23 @@ def _training_step(self, experience: Experience) -> Dict[str, float]:
self.actor_optim.zero_grad()
# value loss
- values = self.critic(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self.critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self.critic(
+ experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+ )
+ critic_loss = self.critic_loss_fn(
+ values, experience.values, experience.reward, action_mask=experience.action_mask
+ )
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
- return {'reward': experience.reward.mean().item()}
+ return {"reward": experience.reward.mean().item()}
def _learn(self, update_step: int):
if self.offload_inference_models:
- self.experience_maker.initial_model.to('cpu')
- self.experience_maker.reward_model.to('cpu')
+ self.experience_maker.initial_model.to("cpu")
+ self.experience_maker.reward_model.to("cpu")
# the buffer may be empty at first, so rebuild the dataloader at each training step
if self.sample_buffer:
@@ -178,11 +171,7 @@ def _learn(self, update_step: int):
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
- pbar = tqdm(
- self.dataloader,
- desc=f'Train epoch [{update_step + 1}]',
- disable=not is_rank_0()
- )
+ pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
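The pretraining-loss mix in _training_step() is a convex combination controlled by ptx_coef; a minimal numeric check:

import torch

ptx_coef = 0.9
actor_loss = torch.tensor(0.5)  # PPO surrogate loss from PolicyLoss
ptx_loss = torch.tensor(2.0)    # LM loss on the pretraining batch
mixed = ptx_loss * ptx_coef + actor_loss * (1 - ptx_coef)
assert torch.isclose(mixed, torch.tensor(2.0 * 0.9 + 0.5 * 0.1))
# ptx_coef = 0 recovers pure PPO; ptx_coef = 1 ignores the PPO loss.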
diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py
index 54a5d0f40dea..a5d6974b3238 100644
--- a/applications/Chat/coati/trainer/rm.py
+++ b/applications/Chat/coati/trainer/rm.py
@@ -62,18 +62,15 @@ def _eval(self, epoch):
if is_rank_0():
log = pd.DataFrame(
- [[(epoch + 1) * len(self.train_dataloader),
- self.loss.item(), self.dist, self.acc]],
- columns=['step', 'loss', 'dist', 'acc']
+ [[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
+ columns=["step", "loss", "dist", "acc"],
)
- log.to_csv('log.csv', mode='a', header=False, index=False)
+ log.to_csv("log.csv", mode="a", header=False, index=False)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
- len(self.train_dataloader),
- desc='Train step of epoch %d' % epoch,
- disable=not is_rank_0()
+ len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
)
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
@@ -93,10 +90,7 @@ def _train(self, epoch):
step_bar.update()
step_bar.close()
- def _before_fit(self,
- train_dataloader: DataLoader,
- valid_dataloader: DataLoader,
- eval_dataloader: DataLoader):
+ def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
@@ -104,7 +98,7 @@ def _before_fit(self,
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
super()._before_fit()
- self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+ self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index e4d0a970740d..8deefc2c484e 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -39,8 +39,9 @@ def __init__(
accumulation_steps: int = 8,
) -> None:
if accumulation_steps > 1:
- assert not isinstance(strategy, GeminiStrategy), \
- "Accumulation steps are not supported in stage 3 of ColossalAI"
+ assert not isinstance(
+ strategy, GeminiStrategy
+ ), "Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, model, optim)
@@ -50,15 +51,11 @@ def __init__(
def _train(self, epoch: int):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
-
batch = to_device(batch, torch.cuda.current_device())
if "attention_mask" in batch:
- outputs = self.model(batch["input_ids"],
- attention_mask=batch["attention_mask"],
- labels=batch["labels"])
+ outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
else:
- outputs = self.model(batch["input_ids"],
- labels=batch["labels"])
+ outputs = self.model(batch["input_ids"], labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
@@ -73,12 +70,14 @@ def _train(self, epoch: int):
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and self.use_wandb:
- wandb.log({
- "loss": self.total_loss / self.accumulation_steps,
- "lr": self.scheduler.get_last_lr()[0],
- "epoch": epoch,
- "batch_id": batch_id
- })
+ wandb.log(
+ {
+ "loss": self.total_loss / self.accumulation_steps,
+ "lr": self.scheduler.get_last_lr()[0],
+ "epoch": epoch,
+ "batch_id": batch_id,
+ }
+ )
self.total_loss = 0
self.step_bar.update()
@@ -89,9 +88,9 @@ def _eval(self, epoch: int):
loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"],
- attention_mask=batch["attention_mask"],
- labels=batch["labels"])
+ outputs = self.model(
+ batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
+ )
loss = outputs.loss
loss_sum += loss.item()
@@ -99,13 +98,15 @@ def _eval(self, epoch: int):
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
- self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
+ self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
- def _before_fit(self,
- train_dataloader: DataLoader,
- eval_dataloader: Optional[DataLoader] = None,
- logger: Optional[DistributedLogger] = None,
- use_wandb: bool = False):
+ def _before_fit(
+ self,
+ train_dataloader: DataLoader,
+ eval_dataloader: Optional[DataLoader] = None,
+ logger: Optional[DistributedLogger] = None,
+ use_wandb: bool = False,
+ ):
"""
Args:
train_dataloader: the dataloader to use for training
@@ -124,6 +125,6 @@ def _before_fit(self,
self.no_epoch_bar = True
self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
- desc=f'steps',
- disable=not is_rank_0()
+ desc=f"steps",
+ disable=not is_rank_0(),
)
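SFTTrainer scales each micro-batch loss by 1/accumulation_steps and steps the optimizer once per window. A self-contained sketch of that pattern; the stepping condition is an assumed reconstruction, since the hunks above only show the scaling and the step itself.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 8

for batch_id in range(32):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss = loss / accumulation_steps   # as in SFTTrainer._train()
    loss.backward()                    # gradients accumulate across micro-batches
    if (batch_id + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()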
diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py
index b49a2c742db3..521dcb5855b1 100644
--- a/applications/Chat/coati/trainer/strategies/__init__.py
+++ b/applications/Chat/coati/trainer/strategies/__init__.py
@@ -2,7 +2,4 @@
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
-__all__ = [
- 'Strategy', 'DDPStrategy',
- 'LowLevelZeroStrategy', 'GeminiStrategy'
-]
+__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index c20b2b16e396..303d4bc220a6 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -19,7 +19,7 @@
class Strategy(ABC):
"""
- Base class for training strategies.
+ Base class for training strategies.
"""
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
@@ -83,16 +83,18 @@ def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _Boo
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
- boost_result = dict(model=model,
- optimizer=optimizer,
- criterion=criterion,
- dataloader=dataloader,
- lr_scheduler=lr_scheduler)
+ boost_result = dict(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ dataloader=dataloader,
+ lr_scheduler=lr_scheduler,
+ )
# remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
- raise RuntimeError(f'Type {type(arg)} is not supported')
+ raise RuntimeError(f"Type {type(arg)} is not supported")
return rets[0] if len(rets) == 1 else rets
@@ -125,11 +127,9 @@ def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, 1, 0)
@abstractmethod
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+ def save_pretrained(
+ self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+ ) -> None:
pass
@abstractmethod
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index fa55f97ad661..4706f9699c91 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy):
"""
- def __init__(self,
- stage: int = 2,
- precision: str = 'fp16',
- seed: int = 42,
- placement_policy: str = 'cuda',
- reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
- overlap_communication: bool = True, # only for stage 1&2
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- max_norm: float = 0.0,
- norm_type: float = 2.0
- ) -> None:
-
+ def __init__(
+ self,
+ stage: int = 2,
+ precision: str = "fp16",
+ seed: int = 42,
+ placement_policy: str = "cuda",
+ reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
+ overlap_communication: bool = True, # only for stage 1&2
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"'
- assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
- assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
+ assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
+ assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
@@ -71,7 +71,7 @@ def __init__(self,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
- cpu_offload=(placement_policy == 'cpu'),
+ cpu_offload=(placement_policy == "cpu"),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
@@ -81,14 +81,15 @@ def __init__(self,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
- norm_type=norm_type
+ norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
- assert isinstance(self.plugin, LowLevelZeroPlugin), \
- f'{type(self).__name__}\'s plugin is not initialized properly.'
+ assert isinstance(
+ self.plugin, LowLevelZeroPlugin
+ ), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
@@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy):
"""
- def __init__(self,
- seed: int = 42,
- shard_init: bool = False, # only for stage 3
- placement_policy: str = 'cuda',
- pin_memory: bool = True, # only for stage 3
- force_outputs_fp32: bool = False, # only for stage 3
- search_range_m: int = 32, # only for stage 3
- hidden_dim: Optional[int] = None, # only for stage 3
- min_chunk_size_m: float = 32, # only for stage 3
- gpu_margin_mem_ratio: float = 0.0, # only for stage 3
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- max_norm: float = 0.0,
- norm_type: float = 2.0
- ) -> None:
-
- assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+ def __init__(
+ self,
+ seed: int = 42,
+ shard_init: bool = False, # only for stage 3
+ placement_policy: str = "cuda",
+ pin_memory: bool = True, # only for stage 3
+ force_outputs_fp32: bool = False, # only for stage 3
+ search_range_m: int = 32, # only for stage 3
+ hidden_dim: Optional[int] = None, # only for stage 3
+ min_chunk_size_m: float = 32, # only for stage 3
+ gpu_margin_mem_ratio: float = 0.0, # only for stage 3
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ) -> None:
+ assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
- f'Shard init is not supported model.from_pretrained() yet. '
- 'Please load weights after strategy.prepare()'
+ f"Shard init is not supported model.from_pretrained() yet. "
+ "Please load weights after strategy.prepare()"
)
self.shard_init = shard_init
- warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+ warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
- precision='fp16',
+ precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
@@ -187,14 +188,13 @@ def __init__(self,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
- norm_type=norm_type
+ norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
- assert isinstance(self.plugin, GeminiPlugin), \
- f'{type(self).__name__}\'s plugin is not initialized properly.'
+ assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
@@ -203,10 +203,9 @@ def model_init_context(self):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
- return ColoInitContext(device=get_current_device(),
- dtype=torch.half,
- default_pg=shard_pg,
- default_dist_spec=default_dist_spec)
+ return ColoInitContext(
+ device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
+ )
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel)
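The loss-scaling knobs shared by both strategies (initial_scale, growth_factor, backoff_factor, growth_interval, hysteresis, min_scale, max_scale) follow the usual dynamic loss-scaling scheme. Below is a simplified sketch of that scheme under stated assumptions, not the plugins' exact implementation.

def step_scale(scale, overflow, good_steps, bad_steps,
               growth_factor=2.0, backoff_factor=0.5, growth_interval=1000,
               hysteresis=2, min_scale=1.0, max_scale=2.0**32):
    # Grow after `growth_interval` overflow-free steps; shrink only after
    # `hysteresis` overflows; clamp to [min_scale, max_scale].
    if overflow:
        good_steps = 0
        bad_steps += 1
        if bad_steps >= hysteresis:
            scale = max(scale * backoff_factor, min_scale)
            bad_steps = 0
    else:
        good_steps += 1
        if good_steps >= growth_interval:
            scale = min(scale * growth_factor, max_scale)
            good_steps = 0
    return scale, good_steps, bad_steps

scale, good, bad = 2.0**16, 0, 0
scale, good, bad = step_scale(scale, overflow=True, good_steps=good, bad_steps=bad)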
diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index a52b0460daa8..66ff6703da4d 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module):
class DDPStrategy(Strategy):
"""
- Strategy for distributed training using torch.distributed.
+ Strategy for distributed training using torch.distributed.
"""
- def __init__(self,
- seed: int = 42,
- plugin_initializer: Callable = TorchDDPPlugin
- ) -> None:
+ def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
self.seed = seed
super().__init__(plugin_initializer)
def _try_init_dist(self, force: bool = False) -> None:
try:
- rank = int(os.environ['RANK'])
- local_rank = int(os.environ['LOCAL_RANK'])
- world_size = int(os.environ['WORLD_SIZE'])
- host = os.environ['MASTER_ADDR']
- port = int(os.environ['MASTER_PORT'])
- dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ host = os.environ["MASTER_ADDR"]
+ port = int(os.environ["MASTER_PORT"])
+ dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
@@ -60,8 +57,7 @@ def _try_init_dist(self, force: bool = False) -> None:
raise e
def _post_init(self) -> None:
- assert isinstance(self.plugin, TorchDDPPlugin), \
- f'{type(self).__name__}\'s plugin is not initialized properly.'
+ assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
self._try_init_dist(force=True)
@@ -73,12 +69,14 @@ def set_seed(self, seed: int) -> None:
torch.manual_seed(seed)
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
- return self.plugin.prepare_dataloader(data_buffer,
- batch_size=data_buffer.sample_batch_size,
- shuffle=True,
- drop_last=True,
- pin_memory=pin_memory,
- collate_fn=data_buffer.collate_fn)
+ return self.plugin.prepare_dataloader(
+ data_buffer,
+ batch_size=data_buffer.sample_batch_size,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=pin_memory,
+ collate_fn=data_buffer.collate_fn,
+ )
def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray; not tested after adapting to the Boost API.
@@ -88,11 +86,9 @@ def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
return model.unwrap()
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+ def save_pretrained(
+ self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+ ) -> None:
if not only_rank0 or dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
@@ -103,17 +99,11 @@ def save_pretrained(self,
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
- self.save_model(model,
- model_path,
- only_rank0=only_rank0)
+ self.save_model(model, model_path, only_rank0=only_rank0)
- def _replace_keys(model_path: str,
- replace_fn: Callable):
+ def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
- state_dict = {
- replace_fn(k): v
- for k, v in state_dict.items()
- }
+ state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)
# FIXME: save_model adds a "model." prefix to the keys of pytorch_model.bin
@@ -124,13 +114,13 @@ def _replace_keys(model_path: str,
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
- if 'requires_grad_only' in config and config['requires_grad_only'] == True:
+ if "requires_grad_only" in config and config["requires_grad_only"] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()
- if 'shard_size' in config:
- shard_size = config['shard_size']
+ if "shard_size" in config:
+ shard_size = config["shard_size"]
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py
index d726fa640fa2..6e811bef11a5 100644
--- a/applications/Chat/coati/trainer/strategies/sampler.py
+++ b/applications/Chat/coati/trainer/strategies/sampler.py
@@ -4,7 +4,6 @@
class DistributedSampler:
-
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
@@ -12,7 +11,7 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None:
if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
- (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
+ (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
@@ -20,10 +19,10 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset)))
- indices = indices[:self.total_size]
+ indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
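A worked example of the strided subsampling above: with 8 samples and 4 replicas, indices[rank : total_size : num_replicas] gives each rank a disjoint stride.

import math

dataset_len, num_replicas = 8, 4
num_samples = math.ceil(dataset_len / num_replicas)  # divisible case above
total_size = num_samples * num_replicas
for rank in range(num_replicas):
    indices = list(range(dataset_len))[:total_size]
    print(rank, indices[rank:total_size:num_replicas])
# 0 [0, 4]
# 1 [1, 5]
# 2 [2, 6]
# 3 [3, 7]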
diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py
index 7e2cb9c634f7..7811e7365eeb 100644
--- a/applications/Chat/coati/trainer/utils.py
+++ b/applications/Chat/coati/trainer/utils.py
@@ -42,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any:
-
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
diff --git a/applications/Chat/evaluate/config/config_cn.json b/applications/Chat/evaluate/config/config_cn.json
index 023f16bef31c..4d30d005df30 100644
--- a/applications/Chat/evaluate/config/config_cn.json
+++ b/applications/Chat/evaluate/config/config_cn.json
@@ -70,7 +70,7 @@
"BLEU",
"ROUGE",
"BERTScore"
- ]
+ ]
},
"logical_reasoning": {
"GPT": [
@@ -83,7 +83,7 @@
"ROUGE",
"BERTScore",
"CHRF"
- ]
+ ]
},
"open_qa": {
"GPT": [
@@ -126,7 +126,7 @@
"conciseness"
],
"Metrics": [
- ]
+ ]
},
"Finance": {
"GPT": [
@@ -134,7 +134,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"Law": {
"GPT": [
@@ -142,7 +142,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"Education": {
"GPT": [
@@ -150,7 +150,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"Medical": {
"GPT": [
@@ -158,7 +158,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"STEM": {
"GPT": [
@@ -166,7 +166,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"SocialScience": {
"GPT": [
@@ -174,7 +174,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"Humanity": {
"GPT": [
@@ -182,7 +182,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"Other": {
"GPT": [
@@ -190,7 +190,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
},
"ethics": {
"GPT": [
@@ -198,7 +198,7 @@
"correctness"
],
"Metrics": [
- ]
+ ]
}
}
}
diff --git a/applications/Chat/evaluate/eval.py b/applications/Chat/evaluate/eval.py
index e3fe0e9e091b..16ef31a94175 100644
--- a/applications/Chat/evaluate/eval.py
+++ b/applications/Chat/evaluate/eval.py
@@ -1,5 +1,4 @@
import argparse
-import json
import os
import openai
@@ -9,7 +8,8 @@
def main(args):
assert len(args.answer_file_list) == len(
- args.model_name_list), "The number of answer files and model names should be equal!"
+ args.model_name_list
+ ), "The number of answer files and model names should be equal!"
# load config
config = jload(args.config_file)
@@ -36,7 +36,8 @@ def main(args):
if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
raise Exception(
- "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!")
+ "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!"
+ )
if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
raise Exception(
@@ -44,8 +45,15 @@ def main(args):
)
# initialize evaluator
- evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model,
- config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference)
+ evaluator = Evaluator(
+ metrics_per_category,
+ battle_prompt,
+ gpt_evaluation_prompt,
+ args.gpt_model,
+ config["language"],
+ config.get("path_for_UniEval", None),
+ args.gpt_with_reference,
+ )
if len(args.model_name_list) == 2:
answers1 = jload(args.answer_file_list[0])
answers2 = jload(args.answer_file_list[1])
@@ -68,41 +76,41 @@ def main(args):
raise ValueError(f'Unsupported language {config["language"]}!')
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.')
- parser.add_argument('--config_file',
- type=str,
- default=None,
- required=True,
- help='path to the file of target results')
- parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle')
- parser.add_argument('--gpt_evaluation_prompt_file',
- type=str,
- default=None,
- help='path to the prompt file for gpt evaluation')
- parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file')
- parser.add_argument('--answer_file_list',
- type=str,
- nargs='+',
- default=[],
- required=True,
- help='path to the answer files of at most 2 models')
- parser.add_argument('--model_name_list',
- type=str,
- nargs='+',
- default=[],
- required=True,
- help='the names of at most 2 models')
- parser.add_argument('--gpt_model',
- default="gpt-3.5-turbo",
- choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
- help='which GPT model to use for evaluation')
- parser.add_argument('--gpt_with_reference',
- default=False,
- action="store_true",
- help='whether to include reference answer in gpt evaluation')
- parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results')
- parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.")
+ parser.add_argument(
+ "--config_file", type=str, default=None, required=True, help="path to the file of target results"
+ )
+ parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle")
+ parser.add_argument(
+ "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation"
+ )
+ parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file")
+ parser.add_argument(
+ "--answer_file_list",
+ type=str,
+ nargs="+",
+ default=[],
+ required=True,
+ help="path to the answer files of at most 2 models",
+ )
+ parser.add_argument(
+ "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models"
+ )
+ parser.add_argument(
+ "--gpt_model",
+ default="gpt-3.5-turbo",
+ choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
+ help="which GPT model to use for evaluation",
+ )
+ parser.add_argument(
+ "--gpt_with_reference",
+ default=False,
+ action="store_true",
+ help="whether to include reference answer in gpt evaluation",
+ )
+ parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results")
+ parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key")
args = parser.parse_args()
if args.openai_key is not None:
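
Note: the reworked parser can be exercised without a shell by handing parse_args an argv list; a minimal sketch with placeholder paths and key (every value below is illustrative):

# Sketch: drive the parser above programmatically; all values are placeholders.
args = parser.parse_args(
    [
        "--config_file", "config/config_cn.json",
        "--answer_file_list", "answers_model_a.json",
        "--model_name_list", "model_a",
        "--openai_key", "sk-placeholder",
    ]
)
assert args.gpt_model == "gpt-3.5-turbo"  # the default choice
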
diff --git a/applications/Chat/evaluate/evaluator.py b/applications/Chat/evaluate/evaluator.py
index 3dd5fd6f2f23..1d998cd2d09c 100644
--- a/applications/Chat/evaluate/evaluator.py
+++ b/applications/Chat/evaluate/evaluator.py
@@ -3,20 +3,27 @@
import gpt_evaluate
import metrics
-import pandas as pd
import unieval
from utils import analyze_automatic_results, get_data_per_category, save_automatic_results
class Evaluator(object):
"""
- A class named Evaluator includes GPT-3.5/GPT-4 evaluation
- and automatic evaluation
+ The Evaluator class runs GPT-3.5/GPT-4 evaluation
+ as well as automatic metric-based evaluation.
"""
- def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any],
- gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None:
+ def __init__(
+ self,
+ params: Dict[str, Any],
+ battle_prompt: Dict[str, Any],
+ gpt_evaluation_prompt: Dict[str, Any],
+ gpt_model: str,
+ language: str,
+ path_for_UniEval: Dict[str, str],
+ gpt_with_reference: bool,
+ ) -> None:
self.params = params
self.battle_prompt = battle_prompt
self.gpt_evaluation_prompt = gpt_evaluation_prompt
@@ -103,7 +110,8 @@ def switch(metric, language):
if self.params[category]["UniEval"] and self.language == "cn":
raise Exception(
- "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.")
+ "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file."
+ )
category_metrics = self.params[category]["UniEval"]
@@ -134,10 +142,9 @@ def switch(metric, language):
sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]
data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
- scores = uni_evaluator.evaluate(data,
- category,
- dims=list(self.unieval_metric_stats[task][category].keys()),
- overall=False)
+ scores = uni_evaluator.evaluate(
+ data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False
+ )
avg_scores = unieval.calculate_average_score(scores)
self.unieval_metric_stats[task][category].update(avg_scores)
@@ -165,7 +172,8 @@ def switch(metric, language):
category,
self.gpt_model,
self.language,
- references=targets_per_category[category] if self.gpt_with_reference else None)
+ references=targets_per_category[category] if self.gpt_with_reference else None,
+ )
def save(self, path: str, model_name_list: List[str]) -> None:
"""
@@ -204,16 +212,18 @@ def save(self, path: str, model_name_list: List[str]) -> None:
gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
- all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0],
- self.gpt_evaluation_results,
- gpt_evaluation_results_save_path)
+ all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
+ model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
+ )
# Start to calculate scores and save statistics.
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
- gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
- gpt_evaluation_statistics_save_path)
+ gpt_evaluate.save_gpt_evaluation_statistics(
+ model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
+ )
# Save charts and csv.
gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
- gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
- gpt_evaluation_analyses_save_path)
+ gpt_evaluate.analyze_gpt_evaluation_statistics(
+ gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
+ )
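
Note: with the keyword-per-line construction above, wiring up an Evaluator reads naturally; a minimal sketch with placeholder inputs (the params dict and prompt mapping here are illustrative, not real config contents):

# Sketch: instantiate the Evaluator above with placeholder inputs.
evaluator = Evaluator(
    params={"open_qa": {"GPT": ["correctness"], "Metrics": [], "UniEval": []}},
    battle_prompt=None,                     # no pairwise battle in this sketch
    gpt_evaluation_prompt={"open_qa": []},  # placeholder prompt mapping
    gpt_model="gpt-3.5-turbo",
    language="en",
    path_for_UniEval=None,
    gpt_with_reference=False,
)
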
diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/Chat/evaluate/gpt_evaluate.py
index 6fcbe63d0253..ad908f4ba48c 100644
--- a/applications/Chat/evaluate/gpt_evaluate.py
+++ b/applications/Chat/evaluate/gpt_evaluate.py
@@ -14,20 +14,18 @@
from utils import jdump, jload
ref_step_template = {
- "en":
- "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
- "cn":
- "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n"
+ "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
+ "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n",
}
ref_answer_template_general = {
"en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
- "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n"
+ "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
}
ref_answer_template_correctness = {
"en": "\nA correct answer is as follows:\n\n{answer}\n\n",
- "cn": "\n标准答案如下:\n\n{answer}\n\n"
+ "cn": "\n标准答案如下:\n\n{answer}\n\n",
}
@@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in
response = openai.ChatCompletion.create(
model="gpt-4",
messages=[
- {
- "role": "system",
- "content": sys_prompt
- },
+ {"role": "system", "content": sys_prompt},
{
"role": "user",
"content": user_prompt,
@@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]:
return [float(sp[0]), float(sp[1])]
else:
raise Exception(f"Invalid score pair. Got {evaluation}.")
- except Exception as e:
+ except Exception:
return [-1, -1]
@@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert len(answer1) == len(answer2)
- handles = []
- evaluation_file = []
-
total_len = len(answer1)
question_idx_list = list(range(total_len))
@@ -140,9 +132,11 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert answer1[i]["id"] == answer2[i]["id"]
answer_id = answer1[i]["id"]
- ques = answer1[i]["instruction"] if answer1[i][
- "input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"]
- cat = answer1[i]["category"]
+ ques = (
+ answer1[i]["instruction"]
+ if answer1[i]["input"] == ""
+ else answer1[i]["instruction"] + " " + answer1[i]["input"]
+ )
+ answer1[i]["category"]
ans1 = answer1[i]["output"]
ans2 = answer2[i]["output"]
@@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
step_to_add = ref_step_template[language]
- for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)"
+ for_the_given_answer = (
+ "{metric} (1-5) (directly give the score for the given answer):"
+ if language == "en"
+ else "{metric} (1-5) (直接对给定答案打分)"
+ )
# adjective is used to describe the word "answer" in the prompt.
adjective = "example" if language == "en" else "示例"
@@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
answer_to_add = ref_answer_template_correctness[language]
answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
- step_to_add = step_to_add.format(metric=metric.lower(),
- adjective=adjective) + for_the_given_answer.format(metric=metric)
+ step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
+ metric=metric
+ )
return answer_to_add + step_to_add
@@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
for j in range(i):
messages_to_send.append(fill_in_message("user", user_messages[j]))
messages_to_send.append(
- fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]))
+ fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
+ )
# Length of user messages == Length of assistant messages + 1
# Because we always expect the API to respond.
@@ -351,13 +352,15 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
return assistant_responses[-1]
-def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
- inst: Dict[str, Any],
- metrics: List[str],
- language: str,
- reference: Dict[str, Any] = None,
- model: str = "gpt-3.5-turbo",
- max_tokens: int = 2048) -> Dict[str, Any]:
+def get_gpt_evaluation_without_logprobs(
+ prompt: Dict[str, Any],
+ inst: Dict[str, Any],
+ metrics: List[str],
+ language: str,
+ reference: Dict[str, Any] = None,
+ model: str = "gpt-3.5-turbo",
+ max_tokens: int = 2048,
+) -> Dict[str, Any]:
"""
Use chat models (gpt-3.5-turbo or gpt-4) to evaluate one model answer.
@@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3
- question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
+ question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"]
inst["evaluation"] = {}
@@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
if prompt_reference:
# Do a 2-round conversation
- response = multiturn_chat_completion([prompt_1st_round, prompt_reference],
- model,
- max_tokens=max_tokens,
- turns=2)
+ response = multiturn_chat_completion(
+ [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
+ )
else:
response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)
@@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
return inst
-def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
- inst: Dict[str, Any],
- metrics: List[str],
- max_tokens: int = 2048) -> Dict[str, Any]:
+def get_gpt_evaluation_with_logprobs(
+ prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
+) -> Dict[str, Any]:
"""
Use the completion model (text-davinci-003) to evaluate one model answer.
Only completion models can return log probabilities.
@@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3
- question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
+ question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"]
inst["evaluation"] = {}
@@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
return inst
-def evaluate(answers: List[Dict],
- prompt: Dict[str, Any],
- metrics: List[str],
- category: str,
- model: str,
- language: str,
- references: List[Dict] = None) -> List[Dict]:
+def evaluate(
+ answers: List[Dict],
+ prompt: Dict[str, Any],
+ metrics: List[str],
+ category: str,
+ model: str,
+ language: str,
+ references: List[Dict] = None,
+) -> List[Dict]:
"""
Use GPT models to evaluate model answers and save evaluation results.
@@ -529,21 +532,23 @@ def evaluate(answers: List[Dict],
if model == "text-davinci-003":
future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
else:
- future = executor.submit(get_gpt_evaluation_without_logprobs,
- prompt,
- inst,
- metrics,
- language,
- reference=None if references is None else references[idx],
- model=model,
- max_tokens=1)
+ future = executor.submit(
+ get_gpt_evaluation_without_logprobs,
+ prompt,
+ inst,
+ metrics,
+ language,
+ reference=None if references is None else references[idx],
+ model=model,
+ max_tokens=1,
+ )
futures.append(future)
for future in tqdm.tqdm(
- concurrent.futures.as_completed(futures),
- desc=f"{category}: ",
- total=len(futures),
+ concurrent.futures.as_completed(futures),
+ desc=f"{category}: ",
+ total=len(futures),
):
evaluations.append(future.result())
@@ -610,12 +615,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) ->
return int(results[0])
else:
raise Exception(f"Invalid score pair. Got {evaluation}.")
- except Exception as e:
+ except Exception:
return 0
-def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any],
- save_path: str) -> Dict[str, Any]:
+def save_gpt_evaluation_results(
+ model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
+) -> Dict[str, Any]:
"""
Save evaluation results for different categories for one model.
@@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
scores[metric].append(0)
elif evaluation["evaluation"][metric]["logprobs"] is not None:
scores[metric].append(
- calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]))
+ calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
+ )
else:
scores[metric].append(
- calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation))
+ calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)
+ )
statistics = {}
for metric in metrics:
@@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
for category in tqdm.tqdm(
- frame_per_category.keys(),
- desc=f"GPT evaluation: ",
- total=len(frame_per_category.keys()),
+ frame_per_category.keys(),
+ desc=f"GPT evaluation: ",
+ total=len(frame_per_category.keys()),
):
data = pd.DataFrame(frame_per_category[category])
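
Note: the battle parser above contracts to return two floats, or the [-1, -1] sentinel when the reply is unparseable. A standalone sketch of that contract (simplified relative to the real function):

# Simplified sketch of the two-score parsing contract used above.
def parse_pair(reply: str) -> list:
    try:
        first, second = reply.split("\n")[0].split(" ")
        return [float(first), float(second)]
    except Exception:
        return [-1, -1]  # sentinel for malformed replies

assert parse_pair("7 8") == [7.0, 8.0]
assert parse_pair("no scores here") == [-1, -1]
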
diff --git a/applications/Chat/evaluate/metrics.py b/applications/Chat/evaluate/metrics.py
index 77f9b6e98044..85ee4de53725 100644
--- a/applications/Chat/evaluate/metrics.py
+++ b/applications/Chat/evaluate/metrics.py
@@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
"""
bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0}
cumulative_bleu = [0] * 4
- weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.),
- (1. / 4., 1. / 4., 1. / 4., 1. / 4.)]
+ weights = [
+ (1.0 / 1.0, 0.0, 0.0, 0.0),
+ (1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0),
+ (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0),
+ (1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0),
+ ]
for pred, target in zip(preds, targets):
if language == "cn":
- pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
- target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()]
+ pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
+ target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()]
elif language == "en":
pred_list = preprocessing_text(pred).split()
target_list = [preprocessing_text(target).split()]
@@ -42,15 +46,14 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
- """Calculate CHRF Score Metric in sentence level.
- """
+ """Calculate CHRF Score Metric in sentence level."""
chrf_score = {"chrf": 0}
cumulative_chrf = []
for pred, target in zip(preds, targets):
if language == "cn":
- pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
- target_list = ' '.join(jieba.cut(preprocessing_text(target))).split()
+ pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
+ target_list = " ".join(jieba.cut(preprocessing_text(target))).split()
elif language == "en":
pred_list = preprocessing_text(pred).split()
target_list = preprocessing_text(target).split()
@@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
all_targets = []
for pred, target in zip(preds, targets):
- pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred))))
- target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target))))
+ pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred))))
+ target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target))))
all_preds.append(pred_list)
all_targets.append(target_list)
@@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
longest common subsequence (LCS) between preds and targets.
"""
rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0}
- all_preds = []
- all_targets = []
rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
for pred, target in zip(preds, targets):
score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target))
- rouge_scores["rouge1"] += score['rouge1'].fmeasure
- rouge_scores["rouge2"] += score['rouge2'].fmeasure
- rouge_scores["rougeL"] += score['rougeL'].fmeasure
+ rouge_scores["rouge1"] += score["rouge1"].fmeasure
+ rouge_scores["rouge2"] += score["rouge2"].fmeasure
+ rouge_scores["rougeL"] += score["rougeL"].fmeasure
rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds)
rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds)
@@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
for pred in preds:
if language == "cn":
- pred_seg_list = ' '.join(jieba.cut(pred)).split()
+ pred_seg_list = " ".join(jieba.cut(pred)).split()
count_segs = len(pred_seg_list)
unique_segs = set(pred_seg_list)
count_unique_chars = len(unique_segs)
@@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
split_pred = preprocessing_text(pred).split()
for n in range(0, 3):
for i in range(0, len(split_pred) - n):
- ngram = ' '.join(split_pred[i:i + n + 1])
+ ngram = " ".join(split_pred[i : i + n + 1])
unique_ngram[n].add(ngram)
all_ngram_count[n] += 1
@@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language
for pred, target in zip(preds, targets):
if language == "cn":
- pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()]
- target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()]
+ pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()]
+ target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()]
elif language == "en":
pred_list = [char for char in preprocessing_text(pred).split()]
target_list = [char for char in preprocessing_text(target).split()]
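
Note: the four weight tuples in bleu_score are the standard cumulative BLEU-1 through BLEU-4 weights; a minimal sketch of applying them with NLTK's sentence_bleu (toy sentences, assuming nltk is installed, as the metric code already implies):

# Sketch: cumulative BLEU-1..4 with NLTK, mirroring the weight tuples above.
from nltk.translate.bleu_score import sentence_bleu

weights = [
    (1.0, 0.0, 0.0, 0.0),
    (0.5, 0.5, 0.0, 0.0),
    (1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0),
    (0.25, 0.25, 0.25, 0.25),
]
reference = [["the", "cat", "sat", "on", "the", "mat"]]
hypothesis = ["the", "cat", "sat", "on", "a", "mat"]
for n, w in enumerate(weights, start=1):
    print(f"BLEU-{n}:", sentence_bleu(reference, hypothesis, weights=w))
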
diff --git a/applications/Chat/evaluate/unieval/__init__.py b/applications/Chat/evaluate/unieval/__init__.py
index dad8d6ad09fa..6ffccdaa0819 100644
--- a/applications/Chat/evaluate/unieval/__init__.py
+++ b/applications/Chat/evaluate/unieval/__init__.py
@@ -7,6 +7,9 @@
)
__all__ = [
- 'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
- 'analyze_unieval_results'
+ "get_evaluator",
+ "convert_data_to_unieval_format",
+ "calculate_average_score",
+ "save_unieval_results",
+ "analyze_unieval_results",
]
diff --git a/applications/Chat/evaluate/unieval/evaluator.py b/applications/Chat/evaluate/unieval/evaluator.py
index 56cc6d2f9e41..bf2bc33a95c0 100644
--- a/applications/Chat/evaluate/unieval/evaluator.py
+++ b/applications/Chat/evaluate/unieval/evaluator.py
@@ -28,29 +28,29 @@
class SumEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for text summarization """
+ def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
+ """Set up evaluator for text summarization"""
self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
+ model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
- cache_dir=cache_dir)
- self.task = 'summarization'
- self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance']
+ cache_dir=cache_dir,
+ )
+ self.task = "summarization"
+ self.dimensions = ["coherence", "consistency", "fluency", "relevance"]
def evaluate(self, data, category, dims=None, overall=True):
"""
- Get the scores of all the given dimensions
+ Get the scores of all the given dimensions
- category: The category to be evaluated.
+ category: The category to be evaluated.
- dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
- four dimensions: coherence, consistency, fluency, relevance.
+ dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
+ four dimensions: coherence, consistency, fluency, relevance.
- overall: indicates whether the overall score is to be calculated.
- Overall score can be customized to a combination of scores based on different
- dimensions. The default here is the average score of all the given dimensions.
+ overall: indicates whether the overall score is to be calculated.
+ Overall score can be customized to a combination of scores based on different
+ dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
@@ -63,12 +63,12 @@ def evaluate(self, data, category, dims=None, overall=True):
for dim in eval_dims:
# Calculate average sentence-level scores for 'consistency' and 'fluency'
- if dim == 'consistency' or dim == 'fluency':
+ if dim == "consistency" or dim == "fluency":
src_list, output_list = [], []
- n_sents = [] # the number of sentences in each generated summary
+ n_sents = [] # the number of sentences in each generated summary
for i in range(n_data):
- source = data[i]['source']
- system_outputs = sent_tokenize(data[i]['system_output'])
+ source = data[i]["source"]
+ system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
@@ -81,24 +81,26 @@ def evaluate(self, data, category, dims=None, overall=True):
score = []
for cur_n_sent in n_sents:
# prevent denominator from being 0
- score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
+ score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
start_idx += cur_n_sent
# Calculate summary-level score for 'coherence' and 'relevance'
- elif dim == 'coherence' or dim == 'relevance':
+ elif dim == "coherence" or dim == "relevance":
src_list, output_list, ref_list = [], [], []
for i in range(n_data):
- src_list.append(data[i]['source'])
- output_list.append(data[i]['system_output'])
- if dim == 'relevance':
- ref_list.append(data[i]['reference'])
+ src_list.append(data[i]["source"])
+ output_list.append(data[i]["system_output"])
+ if dim == "relevance":
+ ref_list.append(data[i]["reference"])
input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization
else:
- raise NotImplementedError('The input format for this dimension is still undefined. \
- Please customize it first.')
+ raise NotImplementedError(
+ "The input format for this dimension is still undefined. "
+ "Please customize it first."
+ )
for i in range(n_data):
eval_scores[i][dim] = score[i]
@@ -106,35 +108,35 @@ def evaluate(self, data, category, dims=None, overall=True):
# Customize your overall score here.
if overall == True:
for i in range(n_data):
- eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
+ eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores
class DialogEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for dialogues """
+ def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
+ """Set up evaluator for dialogues"""
self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path,
+ model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
- cache_dir=cache_dir)
- self.task = 'dialogue'
- self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability']
+ cache_dir=cache_dir,
+ )
+ self.task = "dialogue"
+ self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"]
def evaluate(self, data, category, dims=None, overall=True):
"""
- Get the scores of all the given dimensions
+ Get the scores of all the given dimensions
- category: The category to be evaluated.
+ category: The category to be evaluated.
- dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
- five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
+ dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
+ five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
- overall: indicates whether the overall score is to be calculated.
- Overall score can be customized to a combination of scores based on different
- dimensions. The default here is the average score of all the given dimensions.
+ overall: indicates whether the overall score is to be calculated.
+ Overall score can be customized to a combination of scores based on different
+ dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
@@ -147,50 +149,48 @@ def evaluate(self, data, category, dims=None, overall=True):
for dim in eval_dims:
# Calculate summation score for 'engagingness'
- if dim == 'engagingness':
+ if dim == "engagingness":
src_list, output_list, context_list = [], [], []
- n_sents = [] # the number of sentences in each generated response
+ n_sents = [] # the number of sentences in each generated response
for i in range(n_data):
- source = data[i]['source']
- context = data[i]['context']
- system_outputs = sent_tokenize(data[i]['system_output'])
+ source = data[i]["source"]
+ context = data[i]["context"]
+ system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
context_list.append(context)
output_list.append(system_outputs[j])
- input_list = add_question(dimension=dim,
- output=output_list,
- src=src_list,
- context=context_list,
- task=self.task)
+ input_list = add_question(
+ dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
+ )
sent_score = self.scorer.score(input_list, self.task, category, dim)
# Get the summation score for each sample
start_idx = 0
score = []
for cur_n_sent in n_sents:
- score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]))
+ score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]))
start_idx += cur_n_sent
# Calculate turn-level score for other dimensions
- elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']:
+ elif dim in ["naturalness", "coherence", "groundedness", "understandability"]:
src_list, output_list, context_list = [], [], []
for i in range(n_data):
- src_list.append(data[i]['source'])
- output_list.append(data[i]['system_output'])
- context_list.append(data[i]['context'])
- input_list = add_question(dimension=dim,
- output=output_list,
- src=src_list,
- context=context_list,
- task=self.task)
+ src_list.append(data[i]["source"])
+ output_list.append(data[i]["system_output"])
+ context_list.append(data[i]["context"])
+ input_list = add_question(
+ dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
+ )
score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization
else:
- raise NotImplementedError('The input format for this dimension is still undefined. \
- Please customize it first.')
+ raise NotImplementedError(
+ "The input format for this dimension is still undefined. "
+ "Please customize it first."
+ )
for i in range(n_data):
eval_scores[i][dim] = score[i]
@@ -198,35 +198,35 @@ def evaluate(self, data, category, dims=None, overall=True):
# Customize your overall score here.
if overall == True:
for i in range(n_data):
- eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
+ eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores
class D2tEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for data-to-text """
+ def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
+ """Set up evaluator for data-to-text"""
self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
+ model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
- cache_dir=cache_dir)
- self.task = 'data2text'
- self.dimensions = ['naturalness', 'informativeness']
+ cache_dir=cache_dir,
+ )
+ self.task = "data2text"
+ self.dimensions = ["naturalness", "informativeness"]
def evaluate(self, data, category, dims=None, overall=True):
"""
- Get the scores of all the given dimensions
+ Get the scores of all the given dimensions
- category: The category to be evaluated.
+ category: The category to be evaluated.
- dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
- two dimensions: naturalness and informativeness.
+ dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
+ two dimensions: naturalness and informativeness.
- overall: indicates whether the overall score is to be calculated.
- Overall score can be customized to a combination of scores based on different
- dimensions. The default here is the average score of all the given dimensions.
+ overall: indicates whether the overall score is to be calculated.
+ Overall score can be customized to a combination of scores based on different
+ dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
@@ -240,8 +240,8 @@ def evaluate(self, data, category, dims=None, overall=True):
for dim in eval_dims:
output_list, ref_list = [], []
for i in range(n_data):
- output_list.append(data[i]['system_output'])
- ref_list.append(data[i]['reference'])
+ output_list.append(data[i]["system_output"])
+ ref_list.append(data[i]["reference"])
input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim)
@@ -252,38 +252,38 @@ def evaluate(self, data, category, dims=None, overall=True):
# Customize your overall score here.
if overall == True:
for i in range(n_data):
- eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
+ eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores
class FactEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up evaluator for factual consistency detection """
+ def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
+ """Set up evaluator for factual consistency detection"""
self.scorer = UniEvaluator(
- model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path,
+ model_name_or_path="MingZhong/unieval-fact" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
- cache_dir=cache_dir)
- self.task = 'fact'
- self.dim = 'consistency'
+ cache_dir=cache_dir,
+ )
+ self.task = "fact"
+ self.dim = "consistency"
def evaluate(self, data, category):
"""
- Get the factual consistency score (only 1 dimension for this task)
+ Get the factual consistency score (only 1 dimension for this task)
- category: The category to be evaluated.
+ category: The category to be evaluated.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
# Calculate average sentence-level scores for factual consistency
src_list, output_list = [], []
- n_sents = [] # the number of sentences in the claim
+ n_sents = [] # the number of sentences in the claim
for i in range(n_data):
- source = data[i]['source']
- system_outputs = sent_tokenize(data[i]['system_output'])
+ source = data[i]["source"]
+ system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
@@ -295,7 +295,7 @@ def evaluate(self, data, category):
start_idx = 0
score = []
for cur_n_sent in n_sents:
- score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
+ score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent)
start_idx += cur_n_sent
for i in range(n_data):
@@ -304,28 +304,26 @@ def evaluate(self, data, category):
return eval_scores
-def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None):
- assert task in ['summarization', 'dialogue', 'data2text', 'fact']
- if task == 'summarization':
- return SumEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- elif task == 'dialogue':
- return DialogEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- elif task == 'data2text':
- return D2tEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
- elif task == 'fact':
- return FactEvaluator(model_name_or_path=model_name_or_path,
- max_length=max_length,
- device=device,
- cache_dir=cache_dir)
+def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None):
+ assert task in ["summarization", "dialogue", "data2text", "fact"]
+ if task == "summarization":
+ return SumEvaluator(
+ model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
+ )
+ elif task == "dialogue":
+ return DialogEvaluator(
+ model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
+ )
+ elif task == "data2text":
+ return D2tEvaluator(
+ model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
+ )
+ elif task == "fact":
+ return FactEvaluator(
+ model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
+ )
else:
- raise NotImplementedError('Other tasks are not implemented, \
- please customize specific tasks here.')
+ raise NotImplementedError(
+ "Other tasks are not implemented, "
+ "please customize specific tasks here."
+ )
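
Note: several branches above flatten each output into sentences, score them, then regroup by the recorded sentence counts; the 1e-6 term keeps an empty output from dividing by zero. A standalone sketch of the regrouping with toy values:

# Toy sketch of regrouping flat sentence scores into per-sample averages.
sent_score = [0.9, 0.8, 0.7, 0.6]
n_sents = [3, 0, 1]  # the second sample produced no sentences

start_idx, per_sample = 0, []
for cur_n_sent in n_sents:
    chunk = sent_score[start_idx : start_idx + cur_n_sent]
    per_sample.append(sum(chunk) / (cur_n_sent + 1e-6))  # 1e-6 guards division by zero
    start_idx += cur_n_sent
# per_sample is approximately [0.8, 0.0, 0.6]
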
diff --git a/applications/Chat/evaluate/unieval/scorer.py b/applications/Chat/evaluate/unieval/scorer.py
index 2c70bb9f6ded..45706b833205 100644
--- a/applications/Chat/evaluate/unieval/scorer.py
+++ b/applications/Chat/evaluate/unieval/scorer.py
@@ -27,9 +27,8 @@
class UniEvaluator:
-
- def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
- """ Set up model """
+ def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
+ """Set up model"""
self.device = device
self.max_length = max_length
@@ -47,8 +46,8 @@ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_d
def score(self, inputs, task, category, dim, batch_size=8):
"""
- Get scores for the given samples.
- final_score = postive_score / (postive_score + negative_score)
+ Get scores for the given samples.
+ final_score = positive_score / (positive_score + negative_score)
"""
# The implementation of "forward" in T5 still requires decoder_input_ids.
@@ -58,31 +57,27 @@ def score(self, inputs, task, category, dim, batch_size=8):
pos_score_list, neg_score_list = [], []
for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
- src_list = inputs[i:i + batch_size]
- tgt_list = tgts[i:i + batch_size]
+ src_list = inputs[i : i + batch_size]
+ tgt_list = tgts[i : i + batch_size]
try:
with torch.no_grad():
- encoded_src = self.tokenizer(src_list,
- max_length=self.max_length,
- truncation=True,
- padding=True,
- return_tensors='pt')
- encoded_tgt = self.tokenizer(tgt_list,
- max_length=self.max_length,
- truncation=True,
- padding=True,
- return_tensors='pt')
-
- src_tokens = encoded_src['input_ids'].to(self.device)
- src_mask = encoded_src['attention_mask'].to(self.device)
-
- tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
+ encoded_src = self.tokenizer(
+ src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
+ )
+ encoded_tgt = self.tokenizer(
+ tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
+ )
+
+ src_tokens = encoded_src["input_ids"].to(self.device)
+ src_mask = encoded_src["attention_mask"].to(self.device)
+
+ tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1)
output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
logits = output.logits.view(-1, self.model.config.vocab_size)
- pos_score = self.softmax(logits)[:, self.pos_id] # Yes
- neg_score = self.softmax(logits)[:, self.neg_id] # No
+ pos_score = self.softmax(logits)[:, self.pos_id] # Yes
+ neg_score = self.softmax(logits)[:, self.neg_id] # No
cur_pos_score = [x.item() for x in pos_score]
cur_neg_score = [x.item() for x in neg_score]
@@ -90,8 +85,8 @@ def score(self, inputs, task, category, dim, batch_size=8):
neg_score_list += cur_neg_score
except RuntimeError:
- print(f'source: {src_list}')
- print(f'target: {tgt_list}')
+ print(f"source: {src_list}")
+ print(f"target: {tgt_list}")
exit(0)
score_list = []
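
Note: UniEval phrases every dimension as a yes/no question and normalizes the probabilities of the Yes/No tokens; a minimal torch sketch of the positive / (positive + negative) step, with made-up logits and token ids:

import torch

# Toy sketch of the Boolean-QA normalization used above; the 3-token
# vocabulary and the Yes/No ids are made up for illustration.
logits = torch.tensor([2.0, 0.5, -1.0])
pos_id, neg_id = 0, 1
probs = torch.softmax(logits, dim=-1)
score = probs[pos_id] / (probs[pos_id] + probs[neg_id])
print(float(score))  # near 1.0 when the "Yes" token dominates
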
diff --git a/applications/Chat/evaluate/unieval/utils.py b/applications/Chat/evaluate/unieval/utils.py
index a381e9e590b2..46b0f2907a30 100644
--- a/applications/Chat/evaluate/unieval/utils.py
+++ b/applications/Chat/evaluate/unieval/utils.py
@@ -31,105 +31,142 @@
def add_question(dimension, output, src=None, ref=None, context=None, task=None):
"""
- Add questions to generate input in Bool-QA format for UniEval.
-
- dimension: specific dimension to be evaluated
- src: source input for different NLG tasks. For example, source document for summarization
- and dialogue history for dialogue response generation.
- output: output text generated by the models
- ref: human-annotated groundtruth
- context: the context needed to evaluate several specific dimension. For example,
- additional factual information when evaluating engagingness and groundedness in dialogues.
+ Add questions to generate input in Bool-QA format for UniEval.
+
+ dimension: specific dimension to be evaluated
+ src: source input for different NLG tasks. For example, source document for summarization
+ and dialogue history for dialogue response generation.
+ output: output text generated by the models
+ ref: human-annotated groundtruth
+ context: the context needed to evaluate several specific dimensions. For example,
+ additional factual information when evaluating engagingness and groundedness in dialogues.
"""
input_with_question = []
for i in range(len(output)):
# For summarization
- if task == 'summarization':
- if dimension == 'fluency':
- cur_input = 'question: Is this a fluent paragraph? paragraph: ' + output[i]
- elif dimension == 'coherence':
- cur_input = 'question: Is this a coherent summary to the document? summary: ' + output[
- i] + ' document: ' + src[i]
- elif dimension == 'consistency':
- cur_input = 'question: Is this claim consistent with the document? claim: ' + output[
- i] + ' document: ' + src[i]
- elif dimension == 'relevance':
- cur_input = 'question: Is this summary relevant to the reference? summary: ' + output[
- i] + ' reference: ' + ref[i]
+ if task == "summarization":
+ if dimension == "fluency":
+ cur_input = "question: Is this a fluent paragraph? paragraph: " + output[i]
+ elif dimension == "coherence":
+ cur_input = (
+ "question: Is this a coherent summary to the document? summary: "
+ + output[i]
+ + " document: "
+ + src[i]
+ )
+ elif dimension == "consistency":
+ cur_input = (
+ "question: Is this claim consistent with the document? claim: "
+ + output[i]
+ + " document: "
+ + src[i]
+ )
+ elif dimension == "relevance":
+ cur_input = (
+ "question: Is this summary relevant to the reference? summary: "
+ + output[i]
+ + " reference: "
+ + ref[i]
+ )
else:
raise NotImplementedError(
- 'The input format for this dimension is still undefined. Please customize it first.')
+ "The input format for this dimension is still undefined. Please customize it first."
+ )
# For dialogues
- elif task == 'dialogue':
- if dimension == 'naturalness':
- cur_input = 'question: Is this a natural response in the dialogue? response: ' + output[i]
- elif dimension == 'coherence':
- cur_input = 'question: Is this a coherent response given the dialogue history? response: '\
- + output[i] + ' dialogue history: ' + src[i]
- elif dimension == 'engagingness':
- cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? response: '\
- + output[i] + ' dialogue history: ' + src[i] + ' fact: ' + context[i]
- elif dimension == 'groundedness':
- cur_input = 'question: Is this response consistent with knowledge in the fact? response: '\
- + output[i] + ' fact: ' + context[i]
- elif dimension == 'understandability':
- cur_input = 'question: Is this an understandable response in the dialogue? response: ' + output[i]
+ elif task == "dialogue":
+ if dimension == "naturalness":
+ cur_input = "question: Is this a natural response in the dialogue? response: " + output[i]
+ elif dimension == "coherence":
+ cur_input = (
+ "question: Is this a coherent response given the dialogue history? response: "
+ + output[i]
+ + " dialogue history: "
+ + src[i]
+ )
+ elif dimension == "engagingness":
+ cur_input = (
+ "question: Is this an engaging and informative response according to the dialogue history and fact? response: "
+ + output[i]
+ + " dialogue history: "
+ + src[i]
+ + " fact: "
+ + context[i]
+ )
+ elif dimension == "groundedness":
+ cur_input = (
+ "question: Is this response consistent with knowledge in the fact? response: "
+ + output[i]
+ + " fact: "
+ + context[i]
+ )
+ elif dimension == "understandability":
+ cur_input = "question: Is this an understandable response in the dialogue? response: " + output[i]
else:
raise NotImplementedError(
- 'The input format for this dimension is still undefined. Please customize it first.')
+ "The input format for this dimension is still undefined. Please customize it first."
+ )
# For data-to-text
- elif task == 'data2text':
- if dimension == 'naturalness':
- cur_input = 'question: Is this a fluent utterance? utterance: ' + output[i]
- elif dimension == 'informativeness':
- cur_input = 'question: Is this sentence informative according to the reference? sentence: '\
- + output[i] + ' reference: ' + ref[i]
+ elif task == "data2text":
+ if dimension == "naturalness":
+ cur_input = "question: Is this a fluent utterance? utterance: " + output[i]
+ elif dimension == "informativeness":
+ cur_input = (
+ "question: Is this sentence informative according to the reference? sentence: "
+ + output[i]
+ + " reference: "
+ + ref[i]
+ )
else:
raise NotImplementedError(
- 'The input format for this dimension is still undefined. Please customize it first.')
+ "The input format for this dimension is still undefined. Please customize it first."
+ )
# For factual consistency detection
- elif task == 'fact':
- if dimension == 'consistency':
- cur_input = 'question: Is this claim consistent with the document? claim: ' + output[
- i] + ' document: ' + src[i]
+ elif task == "fact":
+ if dimension == "consistency":
+ cur_input = (
+ "question: Is this claim consistent with the document? claim: "
+ + output[i]
+ + " document: "
+ + src[i]
+ )
else:
- raise NotImplementedError('No other dimensions for the factual consistency detection task.')
+ raise NotImplementedError("No other dimensions for the factual consistency detection task.")
# For new customized tasks
else:
- raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.')
+ raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.")
input_with_question.append(cur_input)
return input_with_question
def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
"""
- Convert the data into the unieval's format.
+ Convert the data into UniEval's format.
- output_list: a list of model output
+ output_list: a list of model output
- src_list: source input for different NLG tasks. For example, source document for summarization
- and dialogue history for dialogue response generation
- ref_list: human-annotated groundtruth
+ src_list: source input for different NLG tasks. For example, source document for summarization
+ and dialogue history for dialogue response generation
+ ref_list: human-annotated groundtruth
"""
json_data = []
for i in range(len(output_list)):
cur = {}
- cur['system_output'] = output_list[i]
+ cur["system_output"] = output_list[i]
if src_list is not None:
- cur['source'] = src_list[i]
+ cur["source"] = src_list[i]
if ref_list is not None:
- cur['reference'] = ref_list[i]
- cur['context'] = ""
+ cur["reference"] = ref_list[i]
+ cur["context"] = ""
json_data.append(cur)
return json_data
def calculate_average_score(scores):
"""
- Calculate average scores for different metrics
+ Calculate average scores for different metrics
- scores: a list of scores for different metrics for each answer
+ scores: a list of scores for different metrics for each answer
"""
metrics = {metric: 0 for metric in scores[0]}
@@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None:
frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))
for metric in tqdm.tqdm(
- frame_per_metric.keys(),
- desc=f"UniEval metrics: ",
- total=len(frame_per_metric.keys()),
+ frame_per_metric.keys(),
+ desc=f"UniEval metrics: ",
+ total=len(frame_per_metric.keys()),
):
data = pd.DataFrame(frame_per_metric[metric])
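
Note: a quick usage sketch of convert_data_to_unieval_format as defined above (toy strings only):

# Toy usage of convert_data_to_unieval_format as defined above.
data = convert_data_to_unieval_format(
    output_list=["a short summary"],
    src_list=["the source document"],
    ref_list=["the reference summary"],
)
assert data[0]["system_output"] == "a short summary"
assert data[0]["context"] == ""  # context is always initialized empty
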
diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py
index 406e43db99aa..10df455b69d7 100644
--- a/applications/Chat/evaluate/utils.py
+++ b/applications/Chat/evaluate/utils.py
@@ -1,7 +1,6 @@
import io
import json
import os
-import re
import string
from typing import Dict
@@ -55,7 +54,7 @@ def jload(f, mode="r"):
def get_json_list(file_path):
- with open(file_path, 'r') as f:
+ with open(file_path, "r") as f:
json_list = []
for line in f:
json_list.append(json.loads(line))
@@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None:
frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv"))
for metric in tqdm.tqdm(
- frame_per_metric.keys(),
- desc=f"automatic metrics: ",
- total=len(frame_per_metric.keys()),
+ frame_per_metric.keys(),
+ desc=f"automatic metrics: ",
+ total=len(frame_per_metric.keys()),
):
data = pd.DataFrame(frame_per_metric[metric])
diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py
index 2fe293957079..d4b17689e9cb 100644
--- a/applications/Chat/examples/community/peft/easy_dataset.py
+++ b/applications/Chat/examples/community/peft/easy_dataset.py
@@ -3,7 +3,6 @@
from typing import Dict, Sequence
import torch
-from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
padding="longest",
max_length=max_length,
truncation=True,
- ) for text in strings
+ )
+ for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo
class EasySupervisedDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
- #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
+ # Split each line into source and target: source is everything up to and including "回答:", target is everything after it.
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
- sources.append(line[:sep_index + 3])
- targets.append(line[sep_index + 3:] + tokenizer.eos_token)
+ sources.append(line[: sep_index + 3])
+ targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else:
sources.append(line)
targets.append("" + tokenizer.eos_token)
@@ -83,15 +82,17 @@ def __str__(self):
class EasyPromptsDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
- all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
+ all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [
- tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
- truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
+ tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
+ "input_ids"
+ ]
+ .to(torch.cuda.current_device())
+ .squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
@@ -110,7 +111,6 @@ def __str__(self):
class EasyRewardDataset(Dataset):
-
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
@@ -120,44 +120,42 @@ def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None
else:
self.end_token = special_token
print(self.end_token)
- #read all lines in the train_file to a list
+ # read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
- prompt = "提问:" + data['prompt'] + " 回答:"
-
- chosen = prompt + data['chosen'] + self.end_token
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen.append({
- "input_ids": chosen_token['input_ids'],
- "attention_mask": chosen_token['attention_mask']
- })
-
- reject = prompt + data['rejected'] + self.end_token
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject.append({
- "input_ids": reject_token['input_ids'],
- "attention_mask": reject_token['attention_mask']
- })
+ prompt = "提问:" + data["prompt"] + " 回答:"
+
+ chosen = prompt + data["chosen"] + self.end_token
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen.append(
+ {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+ )
+
+ reject = prompt + data["rejected"] + self.end_token
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject.append(
+ {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
+ )
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
- return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
- "input_ids"], self.reject[idx]["attention_mask"]
-
- #python representation of the object and the string representation of the object
+ return (
+ self.chosen[idx]["input_ids"],
+ self.chosen[idx]["attention_mask"],
+ self.reject[idx]["input_ids"],
+ self.reject[idx]["attention_mask"],
+ )
+
+ # Python representation of the object and the string representation of the object.
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
@@ -165,26 +163,25 @@ def __str__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
-'''
+"""
Easy SFT just accepts a text file that can be read line by line. However, the dataset will group texts together up to max_length so the LLM will learn the texts' meaning better.
If individual lines are not related, just set is_group_texts to False.
-'''
+"""
class EasySFTDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
- #read the data_file line by line
+ # read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f:
- #encode the text data line by line and put raw python list input_ids only to raw_input_ids list
+ # Encode the text line by line and collect each raw input_ids list into raw_input_ids.
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
- #if the encoded_ids is longer than max_length, then split it into several parts
+ # if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length):
- raw_input_ids.append(encoded_ids[i:i + max_length])
+ raw_input_ids.append(encoded_ids[i : i + max_length])
else:
raw_input_ids.append(encoded_ids)
@@ -196,12 +193,13 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_
if is_group_texts:
for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length:
- #pad the current_input_ids to max_length with tokenizer.pad_token_id
+ # pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
current_input_ids = []
else:
current_input_ids.extend(input_ids)
@@ -210,14 +208,16 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
else:
- #just append the raw_input_ids to max_length
+ # just pad each input_ids entry to max_length
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids)
@@ -227,14 +227,14 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_
def __len__(self):
return len(self.input_ids)
- #get item from dataset
+ # get item from dataset
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
- #generate the dataset description to be printed by print in python
+ # generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
- #generate the dataset description to be printed by print in python
+ # generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py
index fe294868159d..db629e50ed94 100644
--- a/applications/Chat/examples/community/peft/easy_models.py
+++ b/applications/Chat/examples/community/peft/easy_models.py
@@ -4,7 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F
from coati.models.generation import generate
-from coati.models.utils import log_probs_from_logits, masked_mean
+from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM
@@ -24,38 +24,33 @@ def __init__(self, model: nn.Module) -> None:
@torch.no_grad()
def generate(
- self,
- input_ids: torch.Tensor,
- return_action_mask: bool = True,
- **kwargs
+ self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
- pad_token_id = kwargs.get('pad_token_id', None)
+ pad_token_id = kwargs.get("pad_token_id", None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
- eos_token_id = kwargs.get('eos_token_id', None)
+ eos_token_id = kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied; only mask the action part
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
- action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
- return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
+ return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
- def forward(self,
- sequences: torch.LongTensor,
- num_actions: int,
- attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- """Returns action log probs
- """
+ def forward(
+ self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Returns action log probs"""
output = self.model(sequences, attention_mask=attention_mask)
- logits = output['logits']
+ logits = output["logits"]
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_path: str = None) -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_path: str = None,
+ ) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
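
The action mask produced in generate above can be checked in isolation: positions after the first eos in the generated suffix are masked out, and the F.pad shift keeps the eos token itself inside the action span. A sketch with toy tensors rather than real model output:

    import torch
    import torch.nn.functional as F

    eos_token_id, input_len = 2, 2
    sequences = torch.tensor([[5, 7, 9, 2, 4, 4]])  # prompt of length 2, then generated tokens
    action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
    # tensor([[ True, False, False, False]]) -- positions after the first eos are off
    action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
    action_mask[:, :input_len] = False
    action_mask = action_mask[:, 1:]
    print(action_mask[:, -(sequences.size(1) - input_len):])
    # tensor([[ True,  True, False, False]]) -- the generated 9 and the eos count as actions
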
diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py
index 9385e457d852..e49db1d2bc1b 100644
--- a/applications/Chat/examples/community/peft/train_peft_prompts.py
+++ b/applications/Chat/examples/community/peft/train_peft_prompts.py
@@ -1,18 +1,16 @@
import argparse
-import pandas as pd
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.gpt import GPTRM, GPTCritic
+from coati.models.llama import LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
-from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@@ -23,24 +21,24 @@
def main(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
- state_dict = torch.load(args.rm_path, map_location='cpu')
+ state_dict = torch.load(args.rm_path, map_location="cpu")
# configure model
- if args.model == 'bloom':
+ if args.model == "bloom":
# initial_model = BLOOMActor(pretrained=args.pretrain)
- print('Using peft lora to load Bloom model as initial_model')
+ print("Using peft lora to load Bloom model as initial_model")
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
- print('Using peft lora to load Bloom model as initial_model (Done)')
+ print("Using peft lora to load Bloom model as initial_model (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -49,59 +47,59 @@ def main(args):
else:
rm_model_name = args.rm_model
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- print('Loading reward model from', args.rm_path)
+ print("Loading reward model from", args.rm_path)
reward_model.load_state_dict(state_dict)
- if args.strategy != 'colossalai_gemini':
+ if args.strategy != "colossalai_gemini":
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
- if args.model == 'bloom':
+ if args.model == "bloom":
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- print('Using peft lora to load Bloom model as Actor')
+ print("Using peft lora to load Bloom model as Actor")
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
- print('Using peft lora to load Bloom model as Actor (Done)')
+ print("Using peft lora to load Bloom model as Actor (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- print('Loading reward model from', args.rm_path)
+ print("Loading reward model from", args.rm_path)
critic.load_state_dict(state_dict)
del state_dict
- if args.strategy != 'colossalai_gemini':
+ if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@@ -109,18 +107,18 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
- if args.model == 'gpt2':
+ if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
- tokenizer.eos_token = '<\s>'
+ tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,26 +130,27 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
- prompt_dataloader = DataLoader(prompt_dataset,
- shuffle=(prompt_sampler is None),
- sampler=prompt_sampler,
- batch_size=args.train_batch_size)
+ prompt_dataloader = DataLoader(
+ prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
+ )
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
- pretrain_dataloader = DataLoader(pretrain_dataset,
- shuffle=(pretrain_sampler is None),
- sampler=pretrain_sampler,
- batch_size=args.ptx_batch_size,
- collate_fn=data_collator)
+ pretrain_dataloader = DataLoader(
+ pretrain_dataset,
+ shuffle=(pretrain_sampler is None),
+ sampler=pretrain_sampler,
+ batch_size=args.ptx_batch_size,
+ collate_fn=data_collator,
+ )
def tokenize_fn(texts):
# MUST pad to max length to ensure inputs on all ranks have the same length
# Different lengths may cause a hang when using gemini, as ranks would take different numbers of generation steps
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
@@ -178,45 +177,46 @@ def tokenize_fn(texts):
eos_token_id=tokenizer.eos_token_id,
)
- trainer.fit(prompt_dataloader=prompt_dataloader,
- pretrain_dataloader=pretrain_dataloader,
- num_episodes=args.num_episodes,
- num_update_steps=args.num_update_steps,
- num_collect_steps=args.num_collect_steps)
+ trainer.fit(
+ prompt_dataloader=prompt_dataloader,
+ pretrain_dataloader=pretrain_dataloader,
+ num_episodes=args.num_episodes,
+ num_update_steps=args.num_update_steps,
+ num_collect_steps=args.num_collect_steps,
+ )
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(actor_optim,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
- parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='ddp',
- help='strategy to use')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--sft_lora_path', type=str, default=None)
- parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--rm_path', type=str, default=None)
- parser.add_argument('--rm_pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--num_collect_steps', type=int, default=10)
- parser.add_argument('--num_update_steps', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=2)
- parser.add_argument('--ptx_batch_size', type=int, default=1)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--kl_coef', type=float, default=0.1)
- parser.add_argument('--ptx_coef', type=float, default=0.9)
+ parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
+ parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+ parser.add_argument(
+ "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
+ )
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--sft_lora_path", type=str, default=None)
+ parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--rm_path", type=str, default=None)
+ parser.add_argument("--rm_pretrain", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--num_collect_steps", type=int, default=10)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--train_batch_size", type=int, default=2)
+ parser.add_argument("--ptx_batch_size", type=int, default=1)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.9)
args = parser.parse_args()
main(args)
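
The reason tokenize_fn pads every prompt batch to a fixed 96 tokens is that, under gemini, ranks given different input lengths can run different numbers of generation steps and hang on collectives. A quick check of the fixed-shape invariant, with an illustrative tokenizer:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # illustrative
    batch = tokenizer(
        ["short", "a considerably longer prompt than the first one"],
        return_tensors="pt", max_length=96, padding="max_length", truncation=True,
    )
    assert batch["input_ids"].shape == batch["attention_mask"].shape == (2, 96)  # same on every rank
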
diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py
index 4af08e6d0141..0b62dd652adb 100644
--- a/applications/Chat/examples/community/peft/train_peft_sft.py
+++ b/applications/Chat/examples/community/peft/train_peft_sft.py
@@ -1,18 +1,10 @@
import argparse
import os
-import loralib as lora
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models.base import RewardModel
-from coati.models.bloom import BLOOMLM
-from coati.models.gpt import GPTLM
-from coati.models.llama import LlamaLM
-from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from datasets import load_dataset
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
@@ -29,75 +21,76 @@
def train(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
- print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
+ print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
# if args.save_path exists and contains adapter_config.json and adapter_model.bin, load the saved adapter
- if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
- and os.path.exists(args.save_path + '/adapter_model.bin'):
+ if (
+ os.path.exists(args.save_path)
+ and os.path.exists(args.save_path + "/adapter_config.json")
+ and os.path.exists(args.save_path + "/adapter_model.bin")
+ ):
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
# otherwise, use the peft library to create a fresh LoRA adapter
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
# configure LoRA with rank lora_rank
- lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
- inference_mode=False,
- r=lora_rank,
- lora_alpha=32,
- lora_dropout=0.1)
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
+ )
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = AutoTokenizer.from_pretrained(
args.pretrain,
padding_side="right",
use_fast=False,
)
- tokenizer.eos_token = '<\s>'
+ tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
- if args.model == 'llama' and args.strategy == 'colossalai_gemini':
+ if args.model == "llama" and args.strategy == "colossalai_gemini":
# this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
- sub_module_name = '.'.join(name.split('.')[:-1])
- weight_name = name.split('.')[-1]
+ sub_module_name = ".".join(name.split(".")[:-1])
+ weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
- logger.set_level('WARNING')
+ logger.set_level("WARNING")
# configure dataset
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
@@ -108,47 +101,57 @@ def train(args):
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
if eval_dataset is not None:
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=False,
- seed=42,
- drop_last=False,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=False,
+ seed=42,
+ drop_last=False,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ collate_fn=data_collator,
+ pin_memory=True,
+ )
if eval_dataset is not None:
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ eval_dataloader = DataLoader(
+ eval_dataset,
+ shuffle=(eval_sampler is None),
+ sampler=eval_sampler,
+ batch_size=args.batch_size,
+ collate_fn=data_collator,
+ pin_memory=True,
+ )
else:
eval_dataloader = None
- trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- batch_size=args.batch_size,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps)
+ trainer = SFTTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ batch_size=args.batch_size,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ )
trainer.fit(logger=logger, log_interval=args.log_interval)
@@ -156,29 +159,27 @@ def train(args):
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='ddp')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--dataset', type=str, default=None)
- parser.add_argument('--eval_dataset', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='output')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--max_epochs', type=int, default=3)
- parser.add_argument('--batch_size', type=int, default=4)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
- parser.add_argument('--lr', type=float, default=5e-6)
- parser.add_argument('--accumulation_steps', type=int, default=8)
- parser.add_argument('--enable_peft_lora', action='store_true', default=False)
- parser.add_argument("--is_short_text", action='store_true', default=False)
+ parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--dataset", type=str, default=None)
+ parser.add_argument("--eval_dataset", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--enable_peft_lora", action="store_true", default=False)
+ parser.add_argument("--is_short_text", action="store_true", default=False)
args = parser.parse_args()
train(args)
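
The resume-or-create logic above reduces to: load an existing adapter when adapter_config.json and adapter_model.bin are present under save_path, otherwise wrap the base model with a fresh LoraConfig. A minimal sketch with a placeholder path and an illustrative base model:

    import os
    from peft import LoraConfig, PeftModel, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM

    save_path = "output"  # placeholder
    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # illustrative
    if os.path.exists(os.path.join(save_path, "adapter_config.json")) and os.path.exists(
        os.path.join(save_path, "adapter_model.bin")
    ):
        model = PeftModel.from_pretrained(model, save_path)  # resume saved adapter
    else:
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, r=32, lora_alpha=32, lora_dropout=0.1
        )
        model = get_peft_model(model, lora_config)  # fresh adapter
    model.print_trainable_parameters()
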
diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/Chat/examples/community/ray/ray_job_script.py
index 53f304d379fe..e8a1175a9c32 100644
--- a/applications/Chat/examples/community/ray/ray_job_script.py
+++ b/applications/Chat/examples/community/ray/ray_job_script.py
@@ -6,16 +6,25 @@
def main(api_server_endpoint="http://127.0.0.1:8265"):
client = JobSubmissionClient(api_server_endpoint)
client.submit_job(
- entrypoint=
- "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
+ entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
runtime_env={
- "working_dir":
- "applications/Chat",
+ "working_dir": "applications/Chat",
"pip": [
- "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain",
- "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat"
- ]
- })
+ "torch==1.13.1",
+ "transformers>=4.20.1",
+ "datasets",
+ "loralib",
+ "colossalai>=0.2.4",
+ "langchain",
+ "tokenizers",
+ "fastapi",
+ "sse_starlette",
+ "wandb",
+ "sentencepiece",
+ "gpustat",
+ ],
+ },
+ )
if __name__ == "__main__":
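
After submit_job returns, the job id can be used to poll status or fetch logs. A minimal sketch, assuming a Ray cluster whose dashboard listens on the default address; the entrypoint here is a placeholder:

    from ray.job_submission import JobSubmissionClient

    client = JobSubmissionClient("http://127.0.0.1:8265")
    job_id = client.submit_job(entrypoint="python -c 'print(42)'")
    print(client.get_job_status(job_id))  # PENDING / RUNNING / SUCCEEDED / FAILED
    print(client.get_job_logs(job_id))
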
diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/Chat/examples/community/ray/train_prompts_on_ray.py
index 1bba9ad66fbc..8abd83a8b249 100644
--- a/applications/Chat/examples/community/ray/train_prompts_on_ray.py
+++ b/applications/Chat/examples/community/ray/train_prompts_on_ray.py
@@ -26,9 +26,14 @@
class ExperienceCompositionRefs:
-
- def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
- base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
+ def __init__(
+ self,
+ sequences_attention_mask_action_mask_ref: ray.ObjectRef,
+ action_log_probs_ref: ray.ObjectRef,
+ base_action_log_probs_ref: ray.ObjectRef,
+ value_ref: ray.ObjectRef,
+ r_ref: ray.ObjectRef,
+ ) -> None:
self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
self.action_log_probs_ref = action_log_probs_ref
self.base_action_log_probs_ref = base_action_log_probs_ref
@@ -37,14 +42,14 @@ def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, acti
class ExperienceMaker:
-
def __init__(self, kl_coef) -> None:
self.kl_coef = kl_coef
@torch.no_grad()
def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
sequences, attention_mask, action_mask = ray.get(
- experiment_computation_refs.sequences_attention_mask_action_mask_ref)
+ experiment_computation_refs.sequences_attention_mask_action_mask_ref
+ )
action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
r = ray.get(experiment_computation_refs.r_ref)
@@ -58,11 +63,10 @@ def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs
class DistributedTorchRayActor:
-
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
- logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
- level=logging.INFO,
- datefmt='%Y-%m-%d %H:%M:%S')
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+ )
self._model = None
self._world_size = world_size
self._rank = rank
@@ -82,7 +86,7 @@ def _get_current_node_ip():
@staticmethod
def _get_free_port():
with socket.socket() as sock:
- sock.bind(('', 0))
+ sock.bind(("", 0))
return sock.getsockname()[1]
def get_master_addr_port(self):
@@ -90,7 +94,6 @@ def get_master_addr_port(self):
class BasePPORole(DistributedTorchRayActor):
-
def add_experience_maker(self, kl_coef: float = 0.1):
self._experience_maker = ExperienceMaker(kl_coef)
@@ -99,12 +102,12 @@ def make_experience(self, experience_computation_ref: ExperienceCompositionRefs)
def _init_strategy(self, strategy: str):
# configure strategy
- if strategy == 'ddp':
+ if strategy == "ddp":
self._strategy = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif strategy == "colossalai_gemini":
+ self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+ elif strategy == "colossalai_zero2":
+ self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
@@ -124,11 +127,9 @@ def _prepare_model_with_strategy(self, has_optimizer: bool):
def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
raise NotImplementedError()
- def init_model_from_pretrained(self,
- strategy: str,
- model_class: Type[LoRAModule],
- pretrain: str,
- has_optimizer=False):
+ def init_model_from_pretrained(
+ self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False
+ ):
self._init_strategy(strategy)
self._load_model_from_pretrained(model_class, pretrain)
self._prepare_model_with_strategy(has_optimizer)
@@ -138,7 +139,6 @@ def eval(self):
class TrainablePPORole(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -161,38 +161,39 @@ def learn_on_experiences(self, experience_refs):
@ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole):
-
def set_loss_function(self, eps_clip: float):
self._actor_loss_fn = PolicyLoss(eps_clip)
def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
- if model_type == 'gpt2':
+ if model_type == "gpt2":
self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
- elif model_type == 'bloom':
+ elif model_type == "bloom":
self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
- elif model_type == 'opt':
+ elif model_type == "opt":
self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
else:
raise ValueError(f'Unsupported model "{model_type}"')
# Set tokenize function for sequence generation
def _text_input_tokenize_fn(texts):
- batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
+ batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
self._sample_tokenize_function = _text_input_tokenize_fn
def setup_generate_kwargs(self, generate_kwargs: dict):
from coati.trainer.ppo import _set_default_generate_kwargs
+
self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
- self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
- self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id
+ self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id
+ self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id
def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
import pandas as pd
- prompts = pd.read_csv(prompt_url)['prompt']
+
+ prompts = pd.read_csv(prompt_url)["prompt"]
self._sampler = self._strategy.setup_sampler(prompts)
def _generate(self, input_ids, **generate_kwargs):
@@ -214,10 +215,9 @@ def calculate_action_log_probs(self, sequence_attention_action_mask):
def _training_step(self, experience):
num_actions = experience.action_mask.size(1)
action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
- actor_loss = self._actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ actor_loss = self._actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
self._strategy.backward(actor_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -229,17 +229,18 @@ def save_checkpoint(self, save_path, should_save_optimizer: bool):
self._strategy.save_model(self._model, save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if should_save_optimizer:
- self._strategy.save_optimizer(self._optimizer,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ self._strategy.save_optimizer(
+ self._optimizer,
+ "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()),
+ only_rank0=False,
+ )
def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
- encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
+ encoded_input = self._model_tokenizer(prompt, return_tensors="pt")
input_ids = {k: v.cuda() for k, v in encoded_input.items()}
- sequence, _ = self._model.generate(**input_ids,
- max_length=max_length,
- return_action_mask=False,
- num_return_sequences=num_return_sequences)
+ sequence, _ = self._model.generate(
+ **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences
+ )
token_list = list(sequence.data[0])
output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
return output
@@ -247,18 +248,16 @@ def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
@ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole):
-
def set_loss_function(self, value_clip: float):
self._critic_loss_fn = ValueLoss(value_clip)
def _training_step(self, experience):
- values = self._model(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self._critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self._model(
+ experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+ )
+ critic_loss = self._critic_loss_fn(
+ values, experience.values, experience.reward, action_mask=experience.action_mask
+ )
self._strategy.backward(critic_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -272,12 +271,12 @@ def calculate_value(self, sequence_attention_action_mask):
@ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
- self._model = RewardModel(deepcopy(critic.model),
- deepcopy(critic.value_head)).to(torch.cuda.current_device())
+ self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(
+ torch.cuda.current_device()
+ )
@torch.no_grad()
def calculate_r(self, sequence_attention_action_mask):
@@ -287,7 +286,6 @@ def calculate_r(self, sequence_attention_action_mask):
@ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -300,8 +298,8 @@ def calculate_base_action_log_probs(self, sequence_attention_action_mask):
class PPORayActorGroup:
"""
- A group of ray actors
- Functions start with 'async' should return list of object refs
+ A group of Ray actors.
+ Functions starting with 'async' should return a list of object refs.
"""
def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
@@ -319,8 +317,9 @@ def _initiate_actors(self):
pg = placement_group(bundles, strategy="STRICT_SPREAD")
ray.get(pg.ready())
if pg:
- master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
- placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
+ master_actor = self.ray_actor_type.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
+ ).remote(world_size, 0, 0, None, None)
else:
master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
self._actor_handlers = [master_actor]
@@ -331,16 +330,20 @@ def _initiate_actors(self):
for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node
if pg:
- worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
- placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
- world_size, rank, local_rank, master_addr, master_port)
+ worker_actor = self.ray_actor_type.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(
+ placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node
+ )
+ ).remote(world_size, rank, local_rank, master_addr, master_port)
else:
- worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
- master_addr, master_port)
+ worker_actor = self.ray_actor_type.options(num_gpus=1).remote(
+ world_size, rank, local_rank, master_addr, master_port
+ )
self._actor_handlers.append(worker_actor)
- def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
- has_optimizer: bool):
+ def async_init_model_from_pretrained(
+ self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool
+ ):
return [
actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
for actor in self._actor_handlers
@@ -348,7 +351,6 @@ def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRA
class TrainableModelRayActorGroup(PPORayActorGroup):
-
def async_learn_on_experiences(self, experience_refs):
num_actors = len(self._actor_handlers)
learn_result_refs = []
@@ -359,7 +361,6 @@ def async_learn_on_experiences(self, experience_refs):
class PPOActorRayActorGroup(TrainableModelRayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
@@ -381,7 +382,8 @@ def async_calculate_action_log_probs(self, sequences_attention_mask_action_mask_
action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
action_log_probs_refs.append(action_log_probs_ref)
return action_log_probs_refs
@@ -393,7 +395,6 @@ def save_checkpoint(self, save_path, should_save_optimizer):
class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
@@ -402,7 +403,8 @@ def async_calculate_value(self, sequences_attention_mask_action_mask_refs):
value_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
value_refs.append(value_ref)
return value_refs
@@ -411,7 +413,6 @@ def set_loss_function(self, value_clip: float = 0.4):
class PPOInitialRayActorGroup(PPORayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
@@ -420,13 +421,13 @@ def async_calculate_base_action_log_probs(self, sequences_attention_mask_action_
base_action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
base_action_log_probs_refs.append(base_action_log_probs_ref)
return base_action_log_probs_refs
class PPORewardRayActorGroup(PPORayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
@@ -435,20 +436,21 @@ def async_calculate_r(self, sequences_attention_mask_action_mask_refs):
r_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
r_refs.append(r_ref)
return r_refs
def main(args):
- logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
- level=logging.INFO,
- datefmt='%Y-%m-%d %H:%M:%S')
- if args.model == 'gpt2':
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+ )
+ if args.model == "gpt2":
actor_model_class, critic_model_class = GPTActor, GPTCritic
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
- elif args.model == 'opt':
+ elif args.model == "opt":
actor_model_class, critic_model_class = OPTActor, OPTCritic
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -462,13 +464,14 @@ def main(args):
logging.info("Actors created")
# Prepare model for training
- generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
+ generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50}
ray.get(
- actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
- critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
- initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
- reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
- actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
+ actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True)
+ + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True)
+ + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False)
+ + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False)
+ + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)
+ )
logging.info("Models prepared for training")
# Prepare models for training
@@ -483,8 +486,12 @@ def main(args):
# Start training
logging.info("Training start")
# Set all models to eval and add experience maker
- all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
- initial_group._actor_handlers + reward_group._actor_handlers
+ all_ray_actors = (
+ actor_group._actor_handlers
+ + critic_group._actor_handlers
+ + initial_group._actor_handlers
+ + reward_group._actor_handlers
+ )
num_ray_actors = len(all_ray_actors)
ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
@@ -497,18 +504,28 @@ def main(args):
time += 1
# Experience queueing stage
sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
- experience_batch_size)
+ experience_batch_size
+ )
base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
- sequences_attention_mask_action_mask_refs)
+ sequences_attention_mask_action_mask_refs
+ )
values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
action_log_probs_refs = actor_group.async_calculate_action_log_probs(
- sequences_attention_mask_action_mask_refs)
- experience_composition_refs.extend([
- ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
- base_action_log_probs_refs[i], values_refs[i], r_refs[i])
- for i in range(len(sequences_attention_mask_action_mask_refs))
- ])
+ sequences_attention_mask_action_mask_refs
+ )
+ experience_composition_refs.extend(
+ [
+ ExperienceCompositionRefs(
+ sequences_attention_mask_action_mask_refs[i],
+ action_log_probs_refs[i],
+ base_action_log_probs_refs[i],
+ values_refs[i],
+ r_refs[i],
+ )
+ for i in range(len(sequences_attention_mask_action_mask_refs))
+ ]
+ )
# Learning stage
if time % update_timesteps == 0:
experience_refs = []
@@ -519,8 +536,9 @@ def main(args):
experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
# backward
ray.get(
- actor_group.async_learn_on_experiences(experience_refs) +
- critic_group.async_learn_on_experiences(experience_refs))
+ actor_group.async_learn_on_experiences(experience_refs)
+ + critic_group.async_learn_on_experiences(experience_refs)
+ )
# clear refs queue
experience_composition_refs.clear()
logging.info("Training finished")
@@ -528,26 +546,24 @@ def main(args):
actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_csv_url', type=str)
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='ddp')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
- parser.add_argument('--pretrain', type=str, default='gpt2')
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
- parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
- parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
- parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
- parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
+ parser.add_argument("--prompt_csv_url", type=str)
+ parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"])
+ parser.add_argument("--pretrain", type=str, default="gpt2")
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--max_timesteps", type=int, default=10)
+ parser.add_argument("--update_timesteps", type=int, default=10)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1)
+ parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1)
+ parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1)
+ parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1)
+ parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1)
args = parser.parse_args()
ray.init()
main(args)
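
The STRICT_SPREAD placement used by PPORayActorGroup pins one GPU bundle per node and schedules each actor into its bundle. A reduced sketch of the same pattern; it assumes a running cluster with two single-GPU nodes, and the Worker class is illustrative:

    import ray
    from ray.util.placement_group import placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    @ray.remote(num_gpus=1)
    class Worker:
        def whoami(self, rank):
            return rank

    ray.init()
    pg = placement_group([{"GPU": 1}] * 2, strategy="STRICT_SPREAD")  # one bundle per node
    ray.get(pg.ready())
    workers = [
        Worker.options(
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg, placement_group_bundle_index=i
            )
        ).remote()
        for i in range(2)
    ]
    print(ray.get([w.whoami.remote(i) for i, w in enumerate(workers)]))
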
diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py
index c2b5f9a859a9..ec3482b5f789 100644
--- a/applications/Chat/examples/download_model.py
+++ b/applications/Chat/examples/download_model.py
@@ -22,7 +22,7 @@ def download(self, dir_path: str):
file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)
def download_all(self):
- file_path = snapshot_download(self.repo_id)
+ snapshot_download(self.repo_id)
def test_init(model: str, dir_path: str):
@@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str):
actor = GPTActor(config=config)
critic = GPTCritic(config=config)
reward_model = GPTRM(config=config)
- tokenizer = GPT2Tokenizer.from_pretrained(dir_path)
+ GPT2Tokenizer.from_pretrained(dir_path)
elif model == "bloom":
config = BloomConfig.from_pretrained(dir_path)
actor = BLOOMActor(config=config)
critic = BLOOMCritic(config=config)
reward_model = BLOOMRM(config=config)
- tokenizer = BloomTokenizerFast.from_pretrained(dir_path)
+ BloomTokenizerFast.from_pretrained(dir_path)
elif model == "opt":
config = AutoConfig.from_pretrained(dir_path)
actor = OPTActor(config=config)
critic = OPTCritic(config=config)
reward_model = OPTRM(config=config)
- tokenizer = AutoTokenizer.from_pretrained(dir_path)
+ AutoTokenizer.from_pretrained(dir_path)
else:
raise NotImplementedError(f"Model {model} not implemented")
@@ -59,17 +59,12 @@ def test_init(model: str, dir_path: str):
exit(0)
repo_list = {
- "gpt2": HFRepoFiles(
- repo_id="gpt2",
- files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]
- ),
+ "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
"bloom": HFRepoFiles(
- repo_id="bigscience/bloom-560m",
- files=["config.json", "tokenizer.json", "tokenizer_config.json"]
+ repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
),
"opt": HFRepoFiles(
- repo_id="facebook/opt-350m",
- files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
+ repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
),
}
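
The two download paths in this script differ only in granularity: hf_hub_download fetches individual files, while snapshot_download caches the whole repo and returns its local path. A sketch with an illustrative repo id:

    from huggingface_hub import hf_hub_download, snapshot_download

    config_path = hf_hub_download("gpt2", "config.json", local_dir="./gpt2")
    repo_dir = snapshot_download("gpt2")  # downloads/caches the full repo
    print(config_path, repo_dir)
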
diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py
index 8d2fbba955b8..7e03b2d54260 100644
--- a/applications/Chat/examples/generate_conversation_dataset.py
+++ b/applications/Chat/examples/generate_conversation_dataset.py
@@ -31,9 +31,11 @@ def generate_alpaca():
def generate_sharegpt():
# ShareGPT data requires less processing.
conversation_dataset = []
- dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
- data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
- split="train")
+ dataset = load_dataset(
+ "anon8231489123/ShareGPT_Vicuna_unfiltered",
+ data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
+ split="train",
+ )
conversations = dataset["conversations"]
@@ -43,23 +45,24 @@ def generate_sharegpt():
del conv["markdown"]
del conv["text"]
- conversation = dict(type="conversation",
- language="Multilingual",
- dataset="ShareGPT",
- conversations=conversations[idx])
+ conversation = dict(
+ type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
+ )
conversation_dataset.append(conversation)
return conversation_dataset
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--dataset',
- type=str,
- default="All",
- choices=["Alpaca", "ShareGPT", "All"],
- help="which dataset to convert, All will combine Alpaca and ShareGPT")
- parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="All",
+ choices=["Alpaca", "ShareGPT", "All"],
+ help="which dataset to convert, All will combine Alpaca and ShareGPT",
+ )
+ parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
args = parser.parse_args()
conversation_dataset = []
@@ -75,5 +78,5 @@ def generate_sharegpt():
for idx, sample in enumerate(conversation_dataset):
sample["id"] = idx + 1
- with open(args.save_path, mode='w') as f:
+ with open(args.save_path, mode="w") as f:
json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
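For reference, a minimal sketch (illustrative only, not part of the patch) of the record shape the converter above emits once ids are assigned; the {"from", "value"} turn layout is assumed from the ShareGPT format:

# Hypothetical example of one converted record written by the script above.
example_record = {
    "type": "conversation",
    "language": "Multilingual",
    "dataset": "ShareGPT",
    "conversations": [
        {"from": "human", "value": "Hello!"},  # assumed ShareGPT turn layout
        {"from": "gpt", "value": "Hi, how can I help?"},
    ],
    "id": 1,  # ids are assigned as idx + 1 before saving
}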
diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py
index 2abb31c09f82..4eec6feae505 100644
--- a/applications/Chat/examples/generate_prompt_dataset.py
+++ b/applications/Chat/examples/generate_prompt_dataset.py
@@ -6,7 +6,7 @@
def sample(args):
- with open(args.dataset_path, mode='r') as f:
+ with open(args.dataset_path, mode="r") as f:
dataset_list = json.load(f)
sampled_dataset = [
@@ -14,18 +14,14 @@ def sample(args):
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
]
- with open(args.save_path, mode='w') as f:
- json.dump(sampled_dataset, f, indent=4,
- default=str, ensure_ascii=False)
+ with open(args.save_path, mode="w") as f:
+ json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--dataset_path', type=str, default=None,
- required=True, help="path to the pretrain dataset")
- parser.add_argument('--save_path', type=str, default='prompt.json',
- help="path to save the prompt dataset")
- parser.add_argument('--sample_size', type=int,
- default=16384, help="size of the prompt dataset")
+ parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
+ parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
+ parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
args = parser.parse_args()
sample(args)
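The sampling step above reduces to a uniform draw without replacement plus a JSON round-trip; a self-contained sketch under the same assumptions (the input file holds a JSON list):

import json
import random

def sample_dataset(dataset_path: str, save_path: str, sample_size: int) -> None:
    # Load the full dataset, draw a uniform random subset without
    # replacement, and write the subset back out as pretty-printed JSON.
    with open(dataset_path, mode="r") as f:
        dataset_list = json.load(f)
    sampled = random.sample(dataset_list, sample_size)
    with open(save_path, mode="w") as f:
        json.dump(sampled, f, indent=4, default=str, ensure_ascii=False)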
diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py
index e1e57e3cd376..087c49564e43 100644
--- a/applications/Chat/examples/inference.py
+++ b/applications/Chat/examples/inference.py
@@ -11,13 +11,13 @@
def eval(args):
# configure model
- if args.model == 'gpt2':
+ if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain)
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain)
- elif args.model == 'opt':
+ elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain)
- elif args.model == 'llama':
+ elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -28,45 +28,38 @@ def eval(args):
actor.load_state_dict(state_dict)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+ elif args.model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
- tokenizer.eos_token = '<\s>'
+ tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
actor.eval()
- input_ids = tokenizer.encode(args.input,
- return_tensors='pt')\
- .to(torch.cuda.current_device())
- outputs = generate(actor,
- input_ids,
- max_length=args.max_length,
- do_sample=True,
- top_k=50,
- top_p=0.95,
- num_return_sequences=1)
- output = tokenizer.batch_decode(outputs[0],
- skip_special_tokens=True)
+ input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
+ outputs = generate(
+ actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
+ )
+ output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
print(f"[Output]: {''.join(output)}")
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
# We suggest using a pretrained model from HuggingFace; set --pretrain to configure the model
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--model_path', type=str, default=None)
- parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
- parser.add_argument('--max_length', type=int, default=100)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
+ parser.add_argument("--max_length", type=int, default=100)
args = parser.parse_args()
eval(args)
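For the gpt2 case above, the whole inference path is encode, sample, decode; a hedged sketch using plain Hugging Face APIs as stand-ins for coati's GPTActor and generate helper (not the patched code itself):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer.encode("Question: How are you? Answer:", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        input_ids,
        max_length=100,
        do_sample=True,  # top-k / nucleus sampling, as in the script
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])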
diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/Chat/examples/ray/1mmt_prompt.py
index 5dd52f1790e6..8de6219ec4e9 100644
--- a/applications/Chat/examples/ray/1mmt_prompt.py
+++ b/applications/Chat/examples/ray/1mmt_prompt.py
@@ -5,7 +5,6 @@
import pandas as pd
import ray
-import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
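Both helpers above (repeated in mmmt_prompt.py below) rely on standard socket tricks: binding to port 0 asks the OS for any free port, and connect() on a UDP socket sends no packets but selects the outbound interface, revealing the local IP. A standalone sketch:

import socket

def free_port() -> int:
    # Port 0 tells the OS to pick an unused port; read it back afterwards.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

def local_ip() -> str:
    # UDP connect() transmits nothing; it only binds the outbound interface.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]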
@@ -37,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
- 'local_rank': '0',
- 'rank': '0',
- 'world_size': '1',
- 'master_port': maker_port,
- 'master_addr': master_addr
+ "local_rank": "0",
+ "rank": "0",
+ "world_size": "1",
+ "master_port": maker_port,
+ "master_addr": master_addr,
}
# configure tokenizer
@@ -75,27 +77,33 @@ def trainer_model_fn():
eval_performance=True,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
- ) for i, env_info_trainer in enumerate(env_info_trainers)
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
]
def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
+ detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
@@ -130,12 +138,11 @@ def model_fn():
dataset_size = args.experience_batch_size * 4
def build_dataloader():
-
def tokenize_fn(texts):
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
- dataset = pd.read_csv(args.prompt_path)['prompt']
+ dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
@@ -144,32 +151,31 @@ def tokenize_fn(texts):
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_path', type=str, default=None)
- parser.add_argument('--num_trainers', type=int, default=1)
- parser.add_argument('--trainer_strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
- 'colossalai_zero2_cpu'
- ],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ parser.add_argument("--prompt_path", type=str, default=None)
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
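The env_info dicts built in main() mirror the environment variables torch.distributed conventionally reads; a hedged sketch of how a worker might consume one such dict (assumed convention, not code from this patch):

import os

import torch.distributed as dist

def apply_env_info(env_info: dict) -> None:
    # Hypothetical consumer: export the dict as the standard torch.distributed
    # variables, then initialize the process group from the environment.
    os.environ["LOCAL_RANK"] = env_info["local_rank"]
    os.environ["RANK"] = env_info["rank"]
    os.environ["WORLD_SIZE"] = env_info["world_size"]
    os.environ["MASTER_PORT"] = env_info["master_port"]
    os.environ["MASTER_ADDR"] = env_info["master_addr"]
    dist.init_process_group(backend="nccl", init_method="env://")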
diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py
index 76929c9d0144..7c03a0468b02 100644
--- a/applications/Chat/examples/ray/mmmt_prompt.py
+++ b/applications/Chat/examples/ray/mmmt_prompt.py
@@ -5,7 +5,6 @@
import pandas as pd
import ray
-import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
@@ -23,13 +22,13 @@
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
+ s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@@ -37,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
- env_info_trainers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_trainers),
- 'master_port': trainer_port,
- 'master_addr': master_addr
- } for rank in range(args.num_trainers)]
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
# maker_env_info
maker_port = str(get_free_port())
- env_info_makers = [{
- 'local_rank': '0',
- 'rank': str(rank),
- 'world_size': str(args.num_makers),
- 'master_port': maker_port,
- 'master_addr': master_addr
- } for rank in range(args.num_makers)]
+ env_info_makers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_makers),
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_makers)
+ ]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@@ -63,13 +68,18 @@ def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
- if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
- initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
- args.quant_group_size).cuda().requires_grad_(False)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
@@ -78,7 +88,7 @@ def model_fn():
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
- f'trainer{x}'
+ f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@@ -87,8 +97,8 @@ def model_fn():
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
- # sync_models_from_trainers=True,
- # generation kwargs:
+ # sync_models_from_trainers=True,
+ # generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
@@ -128,12 +138,11 @@ def trainer_model_fn():
dataset_size = args.experience_batch_size * 4
def build_dataloader():
-
def tokenize_fn(texts):
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
- dataset = pd.read_csv(args.prompt_path)['prompt']
+ dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
@@ -148,39 +157,44 @@ def tokenize_fn(texts):
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
- total_steps = args.experience_batch_size * args.experience_steps * \
- args.num_makers // (args.num_trainers * args.train_batch_size)
+ total_steps = (
+ args.experience_batch_size
+ * args.experience_steps
+ * args.num_makers
+ // (args.num_trainers * args.train_batch_size)
+ )
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_path', type=str, default=None)
- parser.add_argument('--num_makers', type=int, default=1)
- parser.add_argument('--num_trainers', type=int, default=1)
+ parser.add_argument("--prompt_path", type=str, default=None)
+ parser.add_argument("--num_makers", type=int, default=1)
+ parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
- '--trainer_strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'],
- default='ddp')
- parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--critic_pretrain', type=str, default=None)
- parser.add_argument('--experience_steps', type=int, default=4)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--train_epochs', type=int, default=1)
- parser.add_argument('--update_steps', type=int, default=2)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
- parser.add_argument('--quant_bits', type=int, default=4)
- parser.add_argument('--quant_group_size', type=int, default=128)
- parser.add_argument('--debug', action='store_true')
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
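The reformatted total_steps expression above balances experience production against consumption: makers produce experience_batch_size x experience_steps x num_makers samples in total, and each trainer step consumes train_batch_size of them on each of num_trainers workers. A worked example with the script defaults:

experience_batch_size = 8  # default --experience_batch_size
experience_steps = 4       # default --experience_steps
num_makers = 1             # default --num_makers
num_trainers = 1           # default --num_trainers
train_batch_size = 8       # default --train_batch_size

total_steps = (
    experience_batch_size
    * experience_steps
    * num_makers
    // (num_trainers * train_batch_size)
)
assert total_steps == 4  # 32 experiences / (1 trainer x batches of 8)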
diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt
index 5d0f9f927d17..d3ea7b0c8142 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/Chat/examples/requirements.txt
@@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
-colossalai==0.3.1
\ No newline at end of file
+colossalai==0.3.1
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index d27a70a3fef6..ad688b07a7f2 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -20,28 +20,28 @@
def main(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
- warnings.warn('LoRA weights should be merged with the model weights')
- state_dict = torch.load(args.rm_path, map_location='cpu')
+ warnings.warn("LoRA weights should be merged with the model weights")
+ state_dict = torch.load(args.rm_path, map_location="cpu")
with strategy.model_init_context():
# configure model
- if args.model == 'gpt2':
+ if args.model == "gpt2":
initial_model = GPTActor(pretrained=args.pretrain)
- elif args.model == 'bloom':
+ elif args.model == "bloom":
initial_model = BLOOMActor(pretrained=args.pretrain)
- elif args.model == 'opt':
+ elif args.model == "opt":
initial_model = OPTActor(pretrained=args.pretrain)
- elif args.model == 'llama':
+ elif args.model == "llama":
initial_model = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -51,13 +51,13 @@ def main(args):
else:
rm_model_name = args.rm_model
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -68,24 +68,24 @@ def main(args):
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
- if args.model == 'gpt2':
+ if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'opt':
+ elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'llama':
+ elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
@@ -94,12 +94,12 @@ def main(args):
critic.load_state_dict(state_dict, strict=False)
del state_dict
- if args.strategy != 'colossalai_gemini':
+ if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@@ -107,22 +107,22 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained(
- 'gpt2' if args.tokenizer is None else args.tokenizer)
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
- 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained(
- "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
- tokenizer.eos_token = '<\s>'
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+ tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,27 +132,25 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
- prompt_dataloader = DataLoader(prompt_dataset,
- shuffle=(prompt_sampler is None),
- sampler=prompt_sampler,
- batch_size=args.experience_batch_size)
-
- pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=args.pretrain_dataset,
- max_datasets_size=16384,
- max_length=args.max_input_len)
+ prompt_dataloader = DataLoader(
+ prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
+ )
+
+ pretrain_dataset = SupervisedDataset(
+ tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len
+ )
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
- pretrain_dataloader = DataLoader(pretrain_dataset,
- shuffle=(pretrain_sampler is None),
- sampler=pretrain_sampler,
- batch_size=args.ptx_batch_size)
+ pretrain_dataloader = DataLoader(
+ pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
+ )
# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
- (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
- strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
+ )
# configure trainer
trainer = PPOTrainer(
@@ -173,50 +171,54 @@ def main(args):
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
- offload_inference_models=args.strategy != 'colossalai_gemini'
+ offload_inference_models=args.strategy != "colossalai_gemini",
)
- trainer.fit(prompt_dataloader=prompt_dataloader,
- pretrain_dataloader=pretrain_dataloader,
- num_episodes=args.num_episodes,
- num_collect_steps=args.num_collect_steps,
- num_update_steps=args.num_update_steps)
+ trainer.fit(
+ prompt_dataloader=prompt_dataloader,
+ pretrain_dataloader=pretrain_dataloader,
+ num_episodes=args.num_episodes,
+ num_collect_steps=args.num_collect_steps,
+ num_update_steps=args.num_update_steps,
+ )
# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(actor_optim,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
- parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='colossalai_zero2',
- help='strategy to use')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--tokenizer', type=str, default=None)
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--rm_path', type=str, default=None)
- parser.add_argument('--rm_pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--num_collect_steps', type=int, default=10)
- parser.add_argument('--num_update_steps', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--ptx_batch_size', type=int, default=1)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--kl_coef', type=float, default=0.1)
- parser.add_argument('--ptx_coef', type=float, default=0.9)
- parser.add_argument('--max_input_len', type=int, default=96)
- parser.add_argument('--max_seq_len', type=int, default=128)
+ parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
+ parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+ parser.add_argument(
+ "--strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
+ default="colossalai_zero2",
+ help="strategy to use",
+ )
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--rm_path", type=str, default=None)
+ parser.add_argument("--rm_pretrain", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--num_collect_steps", type=int, default=10)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--ptx_batch_size", type=int, default=1)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.9)
+ parser.add_argument("--max_input_len", type=int, default=96)
+ parser.add_argument("--max_seq_len", type=int, default=128)
args = parser.parse_args()
main(args)
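The sampler/dataloader pairs reformatted above follow one idiom, repeated across these training scripts: use a DistributedSampler only when torch.distributed is initialized with more than one rank, and fall back to in-loader shuffling otherwise. A minimal sketch of the shared pattern:

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_dataloader(dataset, batch_size: int) -> DataLoader:
    # Shuffle via the sampler under torch.distributed, via the loader otherwise;
    # passing both a sampler and shuffle=True would raise a ValueError.
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True)
    else:
        sampler = None
    return DataLoader(
        dataset, shuffle=(sampler is None), sampler=sampler, batch_size=batch_size
    )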
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index 190460bc20f6..a07f4b5ca812 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -24,24 +24,24 @@
def train(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
- if args.model == 'bloom':
+ if args.model == "bloom":
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'opt':
+ elif args.model == "opt":
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'gpt2':
+ elif args.model == "gpt2":
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'llama':
+ elif args.model == "llama":
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -53,36 +53,36 @@ def train(args):
model.load_state_dict(state_dict)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained(
- 'gpt2' if args.tokenizer is None else args.tokenizer)
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
- 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained(
- "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
- tokenizer.eos_token = '<\s>'
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+ tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=5e-6)
else:
optim = Adam(model.parameters(), lr=5e-6)
# configure loss function
- if args.loss_fn == 'log_sig':
+ if args.loss_fn == "log_sig":
loss_fn = LogSigLoss()
- elif args.loss_fn == 'log_exp':
+ elif args.loss_fn == "log_exp":
loss_fn = LogExpLoss()
else:
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
@@ -94,18 +94,18 @@ def train(args):
data = load_dataset(args.dataset)
if args.test:
- train_data = data['train'].select(range(20))
- eval_data = data['test'].select(range(5))
+ train_data = data["train"].select(range(20))
+ eval_data = data["test"].select(range(5))
else:
- train_data = data['train']
- eval_data = data['test']
- valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
+ train_data = data["train"]
+ eval_data = data["test"]
+ valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
- if args.dataset == 'Dahoas/rm-static':
+ if args.dataset == "Dahoas/rm-static":
train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len)
eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
- elif args.dataset == 'Anthropic/hh-rlhf':
+ elif args.dataset == "Anthropic/hh-rlhf":
train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len)
eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
@@ -113,90 +113,99 @@ def train(args):
raise ValueError(f'Unsupported dataset "{args.dataset}"')
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
- valid_sampler = DistributedSampler(valid_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
+ valid_sampler = DistributedSampler(
+ valid_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
valid_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
-
- valid_dataloader = DataLoader(valid_dataset,
- shuffle=(valid_sampler is None),
- sampler=valid_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
-
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
+
+ valid_dataloader = DataLoader(
+ valid_dataset,
+ shuffle=(valid_sampler is None),
+ sampler=valid_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
+
+ eval_dataloader = DataLoader(
+ eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
+ )
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
- model = strategy_dict['model']
- optim = strategy_dict['optimizer']
- lr_scheduler = strategy_dict['lr_scheduler']
- trainer = RewardModelTrainer(model=model,
- strategy=strategy,
- optim=optim,
- lr_scheduler=lr_scheduler,
- loss_fn=loss_fn,
- max_epochs=args.max_epochs)
+ model = strategy_dict["model"]
+ optim = strategy_dict["optimizer"]
+ lr_scheduler = strategy_dict["lr_scheduler"]
+ trainer = RewardModelTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ lr_scheduler=lr_scheduler,
+ loss_fn=loss_fn,
+ max_epochs=args.max_epochs,
+ )
trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
# save the model checkpoint after fitting, on rank0 only
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
- parser.add_argument('--tokenizer', type=str, default=None)
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--model_path', type=str, default=None)
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--dataset',
- type=str,
- choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
- default='Dahoas/rm-static')
- parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
- parser.add_argument('--save_path', type=str, default='rm_ckpt')
- parser.add_argument('--max_epochs', type=int, default=1)
- parser.add_argument('--batch_size', type=int, default=1)
- parser.add_argument('--max_len', type=int, default=512)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
- parser.add_argument('--test', type=bool, default=False)
+ parser.add_argument(
+ "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
+ )
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument(
+ "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
+ )
+ parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
+ parser.add_argument("--save_path", type=str, default="rm_ckpt")
+ parser.add_argument("--max_epochs", type=int, default=1)
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--max_len", type=int, default=512)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
+ parser.add_argument("--test", type=bool, default=False)
args = parser.parse_args()
train(args)
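Both selectable loss functions above are pairwise ranking losses over (chosen, rejected) reward pairs, and they are the same objective written two ways, since -log sigmoid(x) = log(1 + exp(-x)). A hedged sketch of the usual formulations (assumed to match coati's LogSigLoss and LogExpLoss):

import torch
import torch.nn.functional as F

def log_sig_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # -log sigmoid(r_chosen - r_rejected): the Bradley-Terry ranking loss.
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

def log_exp_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # log(1 + exp(r_rejected - r_chosen)): algebraically the same quantity.
    return torch.log(1 + torch.exp(rejected_reward - chosen_reward)).mean()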
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index f068ea2bf5de..1729abb86a09 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -6,18 +6,18 @@
import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor
+from coati.models.chatglm import ChatGLMActor
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
-from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler
@@ -28,14 +28,14 @@
def train(args):
# configure strategy
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2_cpu':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2_cpu":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
@@ -44,23 +44,15 @@ def train(args):
warnings.warn("Gradient checkpoint is disabled when using LoRA")
args.grad_checkpoint = False
with strategy.model_init_context():
- if args.model == 'bloom':
- model = BLOOMActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'opt':
- model = OPTActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'gpt2':
- model = GPTActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'llama':
- model = LlamaActor(pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- checkpoint=args.grad_checkpoint)
- elif args.model == 'chatglm':
+ if args.model == "bloom":
+ model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "opt":
+ model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "gpt2":
+ model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "llama":
+ model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "chatglm":
model = ChatGLMActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -68,144 +60,157 @@ def train(args):
model.to(torch.float16).to(torch.cuda.current_device())
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained(
- 'gpt2' if args.tokenizer is None else args.tokenizer)
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(
- 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained(
- "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(
- "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
- tokenizer.eos_token = '<\s>'
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+ tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
- elif args.model == 'chatglm':
+ elif args.model == "chatglm":
tokenizer = ChatGLMTokenizer.from_pretrained(
- "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
+ "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
+ )
else:
raise ValueError(f'Unsupported model "{args.model}"')
- if args.model == 'llama' and args.strategy == 'colossalai_gemini':
+ if args.model == "llama" and args.strategy == "colossalai_gemini":
# This is a hack to deal with the resized embedding: make sure all
# parameters are ColoParameter for Colossal-AI Gemini compatibility.
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
- sub_module_name = '.'.join(name.split('.')[:-1])
- weight_name = name.split('.')[-1]
+ sub_module_name = ".".join(name.split(".")[:-1])
+ weight_name = name.split(".")[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
# configure dataset
- if args.dataset == 'yizhongw/self_instruct':
- train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
- eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+ if args.dataset == "yizhongw/self_instruct":
+ train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
+ eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
else:
- train_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=args.dataset,
- max_datasets_size=args.max_datasets_size,
- max_length=args.max_len)
+ train_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=args.dataset,
+ max_datasets_size=args.max_datasets_size,
+ max_length=args.max_len,
+ )
eval_dataset = None
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
if eval_dataset is not None:
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=False,
- seed=42,
- drop_last=False,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=False,
+ seed=42,
+ drop_last=False,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
if eval_dataset is not None:
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
+ eval_dataloader = DataLoader(
+ eval_dataset,
+ shuffle=(eval_sampler is None),
+ sampler=eval_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
else:
eval_dataloader = None
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
- lr_scheduler = get_scheduler("cosine",
- optim,
- num_warmup_steps=math.ceil(max_steps * 0.03),
- num_training_steps=max_steps)
+ lr_scheduler = get_scheduler(
+ "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
+ )
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
- model = strategy_dict['model']
- optim = strategy_dict['optimizer']
- lr_scheduler = strategy_dict['lr_scheduler']
- trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- lr_scheduler=lr_scheduler,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps)
-
- trainer.fit(train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- logger=logger,
- use_wandb=args.use_wandb)
+ model = strategy_dict["model"]
+ optim = strategy_dict["optimizer"]
+ lr_scheduler = strategy_dict["lr_scheduler"]
+ trainer = SFTTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ lr_scheduler=lr_scheduler,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ )
+
+ trainer.fit(
+ train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb
+ )
# save the model checkpoint after fitting, on rank0 only
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
- default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
- parser.add_argument('--tokenizer', type=str, default=None)
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--dataset', type=str, default=None)
- parser.add_argument('--max_datasets_size', type=int, default=None)
- parser.add_argument('--save_path', type=str, default='output')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--max_epochs', type=int, default=3)
- parser.add_argument('--batch_size', type=int, default=4)
- parser.add_argument('--max_len', type=int, default=512)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
- parser.add_argument('--lr', type=float, default=5e-6)
- parser.add_argument('--accumulation_steps', type=int, default=8)
- parser.add_argument('--use_wandb', default=False, action='store_true')
- parser.add_argument('--grad_checkpoint', default=False, action='store_true')
+ parser.add_argument(
+ "--strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
+ default="colossalai_zero2",
+ )
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--dataset", type=str, default=None)
+ parser.add_argument("--max_datasets_size", type=int, default=None)
+ parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--max_len", type=int, default=512)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
args = parser.parse_args()
train(args)
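The scheduler setup reformatted above counts optimizer updates, not micro-batches: gradient accumulation divides the dataloader length, and 3% of the resulting schedule is warmup. A worked example assuming len(train_dataloader) == 1000 with the script defaults (--max_epochs 3, --accumulation_steps 8):

import math

steps_per_epoch = 1000 // 8                     # 125 optimizer updates per epoch
max_steps = math.ceil(3 * steps_per_epoch)      # 375 updates across training
num_warmup_steps = math.ceil(max_steps * 0.03)  # 12 warmup steps (3% of schedule)
assert (steps_per_epoch, max_steps, num_warmup_steps) == (125, 375, 12)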
diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py
index 438a1e3ef1c7..dbb5490a63dc 100644
--- a/applications/Chat/inference/benchmark.py
+++ b/applications/Chat/inference/benchmark.py
@@ -84,28 +84,34 @@ def evaluate(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- 'pretrained',
- help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
- parser.add_argument('--quant',
- choices=['8bit', '4bit'],
- default=None,
- help='Quantization mode. Default: None (no quantization, fp16).')
+ "pretrained",
+ help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+ )
+ parser.add_argument(
+ "--quant",
+ choices=["8bit", "4bit"],
+ default=None,
+ help="Quantization mode. Default: None (no quantization, fp16).",
+ )
parser.add_argument(
- '--gptq_checkpoint',
+ "--gptq_checkpoint",
default=None,
- help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
- parser.add_argument('--gptq_group_size',
- type=int,
- default=128,
- help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
+ help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+ )
+ parser.add_argument(
+ "--gptq_group_size",
+ type=int,
+ default=128,
+ help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+ )
args = parser.parse_args()
- if args.quant == '4bit':
- assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+ if args.quant == "4bit":
+ assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
- if args.quant == '4bit':
+ if args.quant == "4bit":
with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
@@ -114,12 +120,12 @@ def evaluate(
else:
model = LlamaForCausalLM.from_pretrained(
args.pretrained,
- load_in_8bit=(args.quant == '8bit'),
+ load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16,
device_map="auto",
)
- if args.quant != '8bit':
- model.half() # seems to fix bugs for some users.
+ if args.quant != "8bit":
+ model.half() # seems to fix bugs for some users.
model.eval()
total_tokens = 0
@@ -129,7 +135,7 @@ def evaluate(
resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1)
total_tokens += tokens
print(f"Response: {resp}")
- print('\n----------------------------\n')
+ print("\n----------------------------\n")
duration = time() - start
- print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s')
- print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB')
+ print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s")
+ print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py
index 9443d4b99180..333262e538ac 100644
--- a/applications/Chat/inference/locustfile.py
+++ b/applications/Chat/inference/locustfile.py
@@ -1,26 +1,26 @@
-from json import JSONDecodeError
-
from locust import HttpUser, task
-samples = [[
- dict(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- dict(instruction='continue this talk', response=''),
-], [
- dict(instruction='Who is the best player in the history of NBA?', response=''),
-]]
+samples = [
+ [
+ dict(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ dict(instruction="continue this talk", response=""),
+ ],
+ [
+ dict(instruction="Who is the best player in the history of NBA?", response=""),
+ ],
+]
class GenerationUser(HttpUser):
-
@task
def generate(self):
for sample in samples:
- data = {'max_new_tokens': 64, 'history': sample}
- with self.client.post('/generate', json=data, catch_response=True) as response:
+ data = {"max_new_tokens": 64, "history": sample}
+ with self.client.post("/generate", json=data, catch_response=True) as response:
if response.status_code in (200, 406):
response.success()
else:
- response.failure('Response wrong')
+ response.failure("Response wrong")
diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py
index 9d6b7fabef54..7c6a61b9e7f2 100644
--- a/applications/Chat/inference/server.py
+++ b/applications/Chat/inference/server.py
@@ -16,7 +16,7 @@
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
MAX_LEN = 512
running_lock = Lock()
@@ -36,11 +36,11 @@ class GenerationTaskReq(BaseModel):
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# set CORS
-origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)
+origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
if origin_spec_from_env is not None:
# allow CORS from the specified origins
- origins = os.environ['CORS_ORIGIN'].split(',')
+ origins = os.environ["CORS_ORIGIN"].split(",")
else:
# allow CORS from all origins
origins = ["*"]
@@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
# TODO(ver217): streaming generation does not support repetition_penalty now
model_kwargs = {
- 'max_generate_tokens': max_new_tokens,
- 'early_stopping': True,
- 'top_k': top_k,
- 'top_p': top_p,
- 'temperature': temperature,
- 'prepare_inputs_fn': model.prepare_inputs_for_generation,
- 'update_model_kwargs_fn': update_model_kwargs_fn,
+ "max_generate_tokens": max_new_tokens,
+ "early_stopping": True,
+ "top_k": top_k,
+ "top_p": top_p,
+ "temperature": temperature,
+ "prepare_inputs_fn": model.prepare_inputs_for_generation,
+ "update_model_kwargs_fn": update_model_kwargs_fn,
}
is_first_word = True
generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
@@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
if is_first_word:
out_string = out_string.lstrip()
is_first_word = False
- elif current_sub_tokens[0].startswith('▁'):
+ elif current_sub_tokens[0].startswith("▁"):
# whitespace will be ignored by the frontend
- out_string = ' ' + out_string
+ out_string = " " + out_string
yield out_string
@@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator):
if await request.is_disconnected():
break
try:
- yield {'event': 'generate', 'data': next(generator)}
+ yield {"event": "generate", "data": next(generator)}
except StopIteration:
- yield {'event': 'end', 'data': ''}
+ yield {"event": "end", "data": ""}
break
-@app.post('/generate/stream')
-@limiter.limit('1/second')
+@app.post("/generate/stream")
+@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
event_source = event_generator(
- request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
+ request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
+ )
return EventSourceResponse(event_source)
-@app.post('/generate')
-@limiter.limit('1/second')
+@app.post("/generate")
+@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
if prompt_processor.has_censored_words(prompt):
return prompt_processor.SAFE_RESPONSE
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
with running_lock:
- output = model.generate(**inputs, **data.dict(exclude={'history'}))
+ output = model.generate(**inputs, **data.dict(exclude={"history"}))
output = output.cpu()
- prompt_len = inputs['input_ids'].size(1)
+ prompt_len = inputs["input_ids"].size(1)
response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True)
out_string = prompt_processor.postprocess_output(out_string)
@@ -126,32 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
return out_string
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- 'pretrained',
- help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
- parser.add_argument('--quant',
- choices=['8bit', '4bit'],
- default=None,
- help='Quantization mode. Default: None (no quantization, fp16).')
+ "pretrained",
+ help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+ )
parser.add_argument(
- '--gptq_checkpoint',
+ "--quant",
+ choices=["8bit", "4bit"],
default=None,
- help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
- parser.add_argument('--gptq_group_size',
- type=int,
- default=128,
- help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
- parser.add_argument('--http_host', default='0.0.0.0')
- parser.add_argument('--http_port', type=int, default=7070)
- parser.add_argument('--profanity_file',
- default=None,
- help='Path to profanity words list. It should be a JSON file containing a list of words.')
+ help="Quantization mode. Default: None (no quantization, fp16).",
+ )
+ parser.add_argument(
+ "--gptq_checkpoint",
+ default=None,
+ help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+ )
+ parser.add_argument(
+ "--gptq_group_size",
+ type=int,
+ default=128,
+ help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+ )
+ parser.add_argument("--http_host", default="0.0.0.0")
+ parser.add_argument("--http_port", type=int, default=7070)
+ parser.add_argument(
+ "--profanity_file",
+ default=None,
+ help="Path to profanity words list. It should be a JSON file containing a list of words.",
+ )
args = parser.parse_args()
- if args.quant == '4bit':
- assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+ if args.quant == "4bit":
+ assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
@@ -161,7 +170,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
censored_words = []
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
- if args.quant == '4bit':
+ if args.quant == "4bit":
with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
@@ -170,12 +179,12 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
else:
model = LlamaForCausalLM.from_pretrained(
args.pretrained,
- load_in_8bit=(args.quant == '8bit'),
+ load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16,
device_map="auto",
)
- if args.quant != '8bit':
- model.half() # seems to fix bugs for some users.
+ if args.quant != "8bit":
+ model.half() # seems to fix bugs for some users.
model.eval()
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
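The /generate/stream endpoint emits server-sent events named "generate" until a terminating "end" event. A client can be sketched with plain requests line parsing (no SSE client library assumed; host and port match the defaults above):

import requests

payload = {"max_new_tokens": 64, "history": [{"instruction": "Hello", "response": ""}]}
with requests.post("http://localhost:7070/generate/stream", json=payload, stream=True) as r:
    event = None
    for raw in r.iter_lines(decode_unicode=True):
        if raw.startswith("event:"):
            event = raw.split(":", 1)[1].strip()
        elif raw.startswith("data:"):
            if event == "end":
                break
            print(raw.split(":", 1)[1].strip(), flush=True)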
diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py
index 23028d4959cb..9835e71894c6 100644
--- a/applications/Chat/inference/tests/test_chat_prompt.py
+++ b/applications/Chat/inference/tests/test_chat_prompt.py
@@ -3,41 +3,49 @@
from transformers import AutoTokenizer
from utils import ChatPromptProcessor, Dialogue
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
-tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH'])
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
+tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"])
samples = [
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 128,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 200,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 200,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 211,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 211,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n",
),
- ([
- Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
- ], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
+ (
+ [
+ Dialogue(instruction="Who is the best player in the history of NBA?", response=""),
+ ],
+ 128,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n",
),
]
@@ -49,5 +57,5 @@ def test_chat_prompt_processor():
assert prompt == result
-if __name__ == '__main__':
+if __name__ == "__main__":
test_chat_prompt_processor()
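The three multi-turn samples above differ only in how much history fits the token budget: at 128 new tokens both dialogues fit, at 200 the older exchange is dropped, and at 211 even the final instruction is truncated. The history-fitting loop reduces to the sketch below, where token_len stands in for the tokenizer-based length used by ChatPromptProcessor:

def fit_history(history, budget, token_len):
    # keep the newest dialogues that fit within `budget` tokens
    kept = []
    for text in reversed(history):
        cost = token_len(text)
        if budget - cost < 0:
            break
        budget -= cost
        kept.insert(0, text)
    return kept

# toy usage with character counts standing in for token counts
print(fit_history(["a" * 150, "b" * 30], budget=60, token_len=len))  # only the newer entry fits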
diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py
index e8e7b05ac719..af018adf6e9d 100644
--- a/applications/Chat/inference/utils.py
+++ b/applications/Chat/inference/utils.py
@@ -20,9 +20,9 @@
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-def prepare_logits_processor(top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None) -> LogitsProcessorList:
+def prepare_logits_processor(
+ top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
-def sample_streamingly(model: nn.Module,
- input_ids: torch.Tensor,
- max_generate_tokens: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> Generator:
-
+def sample_streamingly(
+ model: nn.Module,
+ input_ids: torch.Tensor,
+ max_generate_tokens: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> Generator:
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(max_generate_tokens):
- model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
- 'input_ids': input_ids
- }
+ model_inputs = (
+ prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+ )
outputs = model(**model_inputs)
- next_token_logits = outputs['logits'][:, -1, :]
+ next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
@@ -107,25 +108,26 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
return model_kwargs
class Dialogue(BaseModel):
- instruction: str = Field(min_length=1, example='Count up from 1 to 500.')
- response: str = Field(example='')
+ instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
+ response: str = Field(example="")
-def _format_dialogue(instruction: str, response: str = ''):
- return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
+def _format_dialogue(instruction: str, response: str = ""):
+ return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"
-STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
+STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))
class ChatPromptProcessor:
- SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
+ SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer
@@ -138,42 +140,48 @@ def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words:
def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
if self.context_len is None:
- self.context_len = len(self.tokenizer(self.context)['input_ids'])
+ self.context_len = len(self.tokenizer(self.context)["input_ids"])
if self.dialogue_placeholder_len is None:
self.dialogue_placeholder_len = len(
- self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids'])
+ self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
+ )
prompt = self.context
# the last dialogue must be in the prompt
last_dialogue = history.pop()
# the response of the last dialogue is empty
- assert last_dialogue.response == ''
- if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)
- ['input_ids']) + max_new_tokens + self.context_len >= self.max_len:
+ assert last_dialogue.response == ""
+ if (
+ len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
+ + max_new_tokens
+ + self.context_len
+ >= self.max_len
+ ):
# to avoid truncating the placeholder, truncate the original instruction instead
- instruction_truncated = self.tokenizer(last_dialogue.instruction,
- add_special_tokens=False,
- truncation=True,
- max_length=(self.max_len - max_new_tokens - self.context_len -
- self.dialogue_placeholder_len))['input_ids']
+ instruction_truncated = self.tokenizer(
+ last_dialogue.instruction,
+ add_special_tokens=False,
+ truncation=True,
+ max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
+ )["input_ids"]
instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
prompt += _format_dialogue(instruction_truncated)
return prompt
- res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids'])
+ res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])
rows = []
for dialogue in history[::-1]:
text = _format_dialogue(dialogue.instruction, dialogue.response)
- cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
+ cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
if res_len - cur_len < 0:
break
res_len -= cur_len
rows.insert(0, text)
- prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
+ prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
return prompt
def postprocess_output(self, output: str) -> str:
- output = STOP_PAT.sub('', output)
+ output = STOP_PAT.sub("", output)
return output.strip()
def has_censored_words(self, text: str) -> bool:
@@ -184,7 +192,6 @@ def has_censored_words(self, text: str) -> bool:
class LockedIterator:
-
def __init__(self, it, lock: Lock) -> None:
self.lock = lock
self.it = iter(it)
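prepare_logits_processor above composes standard HuggingFace warpers, and each step of sample_streamingly reduces to warp-then-sample. A self-contained sketch of that inner step with toy tensors:

import torch
from transformers.generation import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

processors = LogitsProcessorList(
    [TemperatureLogitsWarper(0.7), TopKLogitsWarper(50), TopPLogitsWarper(0.9)]
)

input_ids = torch.randint(0, 100, (2, 8))  # running sequences (toy data)
next_token_logits = torch.randn(2, 100)  # logits for the last position (toy data)
warped = processors(input_ids, next_token_logits)
probs = torch.softmax(warped, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)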
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
index eb1a77875acb..809fbd4bb86b 100644
--- a/applications/Chat/requirements-test.txt
+++ b/applications/Chat/requirements-test.txt
@@ -1,2 +1,2 @@
pytest
-colossalai==0.3.1
\ No newline at end of file
+colossalai==0.3.1
diff --git a/applications/Chat/setup.py b/applications/Chat/setup.py
index a285a6dff4bf..eb44b6203ef8 100644
--- a/applications/Chat/setup.py
+++ b/applications/Chat/setup.py
@@ -2,40 +2,42 @@
def fetch_requirements(path):
- with open(path, 'r') as fd:
+ with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme():
- with open('README.md', encoding='utf-8') as f:
+ with open("README.md", encoding="utf-8") as f:
return f.read()
def fetch_version():
- with open('version.txt', 'r') as f:
+ with open("version.txt", "r") as f:
return f.read().strip()
setup(
- name='coati',
+ name="coati",
version=fetch_version(),
- packages=find_packages(exclude=(
- 'tests',
- 'benchmarks',
- '*.egg-info',
- )),
- description='Colossal-AI Talking Intelligence',
+ packages=find_packages(
+ exclude=(
+ "tests",
+ "benchmarks",
+ "*.egg-info",
+ )
+ ),
+ description="Colossal-AI Talking Intelligence",
long_description=fetch_readme(),
- long_description_content_type='text/markdown',
- license='Apache Software License 2.0',
- url='https://github.com/hpcaitech/Coati',
- install_requires=fetch_requirements('requirements.txt'),
- python_requires='>=3.6',
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech/Coati",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
classifiers=[
- 'Programming Language :: Python :: 3',
- 'License :: OSI Approved :: Apache Software License',
- 'Environment :: GPU :: NVIDIA CUDA',
- 'Topic :: Scientific/Engineering :: Artificial Intelligence',
- 'Topic :: System :: Distributed Computing',
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: System :: Distributed Computing",
],
)
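With these metadata helpers in place, the package installs in the usual way from applications/Chat, e.g. `pip install .` (or `pip install -e .` for development); either route resolves dependencies through the requirements.txt read by fetch_requirements above.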
diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py
index 3a3bf5b19cb8..e3058be2e67c 100644
--- a/applications/Chat/tests/test_checkpoint.py
+++ b/applications/Chat/tests/test_checkpoint.py
@@ -22,10 +22,7 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict:
return dict(input_ids=input_ids, attention_mask=attention_mask)
-def train_step(strategy: Strategy,
- actor: GPTActor,
- actor_optim: HybridAdam,
- batch_size: int = 8):
+def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
data = get_data(batch_size)
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
actor_output = actor(data["input_ids"], data["attention_mask"])
@@ -35,8 +32,7 @@ def train_step(strategy: Strategy,
strategy.optimizer_step(actor_optim)
-def run_test_checkpoint(strategy_name: str,
- shard: bool):
+def run_test_checkpoint(strategy_name: str, shard: bool):
if strategy_name == "ddp":
strategy = DDPStrategy()
elif strategy_name == "colossalai_gemini":
@@ -60,11 +56,9 @@ def run_test_checkpoint(strategy_name: str,
dist.broadcast_object_list(rank0_dirname)
rank0_dirname = rank0_dirname[0]
- model_path = os.path.join(
- rank0_dirname, "model" if shard else f"model.pt")
+ model_path = os.path.join(rank0_dirname, "model" if shard else "model.pt")
strategy.save_model(actor, model_path, only_rank0=not shard)
- optim_path = os.path.join(
- rank0_dirname, "optim" if shard else "optim.pt")
+ optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
dist.barrier()
@@ -75,11 +69,7 @@ def run_test_checkpoint(strategy_name: str,
train_step(strategy, actor, actor_optim)
-def run_dist(rank: int,
- world_size: int,
- port: int,
- strategy_name: str,
- shard: bool):
+def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
@@ -93,13 +83,8 @@ def run_dist(rank: int,
@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use()
-def test_checkpoint(world_size: int,
- strategy_name: str,
- shard: bool):
- spawn(run_dist,
- world_size,
- strategy_name=strategy_name,
- shard=shard)
+def test_checkpoint(world_size: int, strategy_name: str, shard: bool):
+ spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard)
if __name__ == "__main__":
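spawn here (from colossalai.testing) fans the test body out across ranks; for readers unfamiliar with the helper, the equivalent pattern with plain torch.multiprocessing is roughly the sketch below (the port is arbitrary):

import os

import torch.multiprocessing as mp

def run_dist(rank: int, world_size: int, port: int):
    # each worker sets the env vars read by torch.distributed.init_process_group
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    ...  # set up the strategy and run the actual test body here

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_dist, args=(world_size, 29500), nprocs=world_size)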
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
index f9dee1bae935..3de2cc528967 100644
--- a/applications/Chat/tests/test_dataset.py
+++ b/applications/Chat/tests/test_dataset.py
@@ -8,62 +8,40 @@
from coati.dataset.prompt_dataset import PromptDataset
from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+
SFT_DATASET = [
{
- "instruction":
- "Provide a list of the top 10 most popular mobile games in Asia",
- "input":
- "",
- "output":
- "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id":
- 0
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0,
},
{
- "instruction":
- "Please provide an action plan for reducing carbon footprint on a corporate level",
- "input":
- "",
- "output":
- "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
- "id":
- 1
+ "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
+ "input": "",
+ "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
+ "id": 1,
},
{
- "instruction":
- "Write a persuasive email to your boss explaining why you should have a pay raise",
- "input":
- "",
- "output":
- "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
- "id":
- 2
+ "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
+ "input": "",
+ "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
+ "id": 2,
},
]
PROMPT_DATASET = [
{
- "instruction":
- "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
- "id":
- 0
- },
- {
- "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
- "id": 1
- },
- {
- "instruction": "Write a persuasive essay arguing why homework should be banned in schools",
- "id": 2
- },
- {
- "instruction": "Create a chart comparing the statistics on student debt in the United States.",
- "id": 3
+ "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."',
+ "id": 0,
},
+ {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1},
+ {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2},
+ {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3},
]
@@ -120,10 +98,12 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
json.dump(PROMPT_DATASET, f)
tokenizer = make_tokenizer(model)
assert tokenizer.padding_side in ("left", "right")
- prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name),
- tokenizer=tokenizer,
- max_datasets_size=max_datasets_size,
- max_length=max_length)
+ prompt_dataset = PromptDataset(
+ data_path=os.path.join(tmp_dir, dataset_name),
+ tokenizer=tokenizer,
+ max_datasets_size=max_datasets_size,
+ max_length=max_length,
+ )
assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
for i in range(len(prompt_dataset)):
assert isinstance(prompt_dataset[i], dict)
@@ -137,14 +117,14 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
-@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"),
- ("Dahoas/rm-static", None)])
+@pytest.mark.parametrize(
+ ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)]
+)
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
data = load_dataset(dataset_path, data_dir=subset)
- assert max_datasets_size <= len(data["train"]) \
- and max_datasets_size <= len(data["test"])
+ assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"])
train_data = data["train"].select(range(max_datasets_size))
test_data = data["test"].select(range(max_datasets_size))
tokenizer = make_tokenizer(model)
@@ -162,8 +142,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert len(train_dataset) == len(test_dataset) == max_datasets_size
for i in range(max_datasets_size):
chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
- assert chosen_ids.shape == c_mask.shape == \
- reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
@@ -180,8 +159,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert torch.all(r_mask)
chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
- assert chosen_ids.shape == c_mask.shape == \
- reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
@@ -198,7 +176,6 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
assert torch.all(r_mask)
-
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@@ -214,10 +191,12 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
dataset_name = "sft_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
json.dump(SFT_DATASET, f)
- sft_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=os.path.join(tmp_dir, dataset_name),
- max_datasets_size=max_dataset_size,
- max_length=max_length)
+ sft_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=os.path.join(tmp_dir, dataset_name),
+ max_datasets_size=max_dataset_size,
+ max_length=max_length,
+ )
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
if isinstance(tokenizer, ChatGLMTokenizer):
@@ -227,20 +206,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"]
assert input_ids.shape == labels.shape == torch.Size([max_length])
-
+
ignore_mask = labels == IGNORE_INDEX
assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
return
-
+
for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"]
attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
- assert input_ids.shape == labels.shape == \
- attention_mask.shape == torch.Size([max_length])
+ assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length])
if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
@@ -254,13 +232,8 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
if __name__ == "__main__":
test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
- test_reward_dataset(model="gpt2",
- dataset_path="Anthropic/hh-rlhf",
- subset="harmless-base",
- max_datasets_size=8,
- max_length=256)
-
- test_prompt_dataset(model="opt",
- max_datasets_size=2,
- max_length=128)
+ test_reward_dataset(
+ model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256
+ )
+ test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
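Outside pytest, the same datasets can be built directly. A sketch using the SupervisedDataset signature exercised above (the record mirrors SFT_DATASET; gpt2 is just an example checkpoint and ships without a pad token):

import json
import os
import tempfile

from coati.dataset.sft_dataset import SupervisedDataset
from transformers import AutoTokenizer

records = [{"instruction": "Say hi", "input": "", "output": "Hi!", "id": 0}]
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "sft_dataset.json")
    with open(path, "w") as f:
        json.dump(records, f)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
    dataset = SupervisedDataset(
        tokenizer=tokenizer, data_path=path, max_datasets_size=1, max_length=128
    )
    sample = dataset[0]  # dict with input_ids, labels and attention_mask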
diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py
index 071e50b90e8e..d0ea3bbd2ff5 100644
--- a/applications/Chat/tests/test_experience.py
+++ b/applications/Chat/tests/test_experience.py
@@ -18,7 +18,7 @@
def get_data(batch_size: int, seq_len: int = 10) -> dict:
- input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
+ input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -37,12 +37,12 @@ def make_and_consume_experience(strategy):
EXPERIENCE_BATCH_SIZE = 4
SAMPLE_BATCH_SIZE = 2
- if strategy == 'ddp':
+ if strategy == "ddp":
strategy = DDPStrategy()
- elif strategy == 'colossalai-zero2':
+ elif strategy == "colossalai-zero2":
strategy = LowLevelZeroStrategy()
- elif strategy == 'colossalai-gemini':
- strategy = GeminiStrategy(placement_policy='cuda')
+ elif strategy == "colossalai-gemini":
+ strategy = GeminiStrategy(placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
@@ -58,13 +58,11 @@ def make_and_consume_experience(strategy):
# experience of all ranks should be the same
for _ in range(2):
data = get_data(EXPERIENCE_BATCH_SIZE)
- assert gather_and_equal(data['input_ids'])
- assert gather_and_equal(data['attention_mask'])
- experience = experience_maker.make_experience(**data,
- do_sample=True,
- max_length=16,
- eos_token_id=50256,
- pad_token_id=50256)
+ assert gather_and_equal(data["input_ids"])
+ assert gather_and_equal(data["attention_mask"])
+ experience = experience_maker.make_experience(
+ **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
+ )
assert gather_and_equal(experience.sequences)
assert gather_and_equal(experience.action_log_probs)
assert gather_and_equal(experience.values)
@@ -75,7 +73,7 @@ def make_and_consume_experience(strategy):
data_buffer.append(experience)
# data buffer's data should be the same
- buffer_size = torch.tensor([len(data_buffer)], device='cuda')
+ buffer_size = torch.tensor([len(data_buffer)], device="cuda")
assert gather_and_equal(buffer_size)
for item in data_buffer.items:
assert gather_and_equal(item.sequences)
@@ -88,7 +86,7 @@ def make_and_consume_experience(strategy):
# dataloader of each rank should have the same size and different batch
dataloader = strategy.setup_dataloader(data_buffer)
- dataloader_size = torch.tensor([len(dataloader)], device='cuda')
+ dataloader_size = torch.tensor([len(dataloader)], device="cuda")
assert gather_and_equal(dataloader_size)
for experience in dataloader:
assert not gather_and_equal(experience.sequences)
@@ -100,21 +98,21 @@ def make_and_consume_experience(strategy):
def run_dist(rank, world_size, port, strategy):
- os.environ['RANK'] = str(rank)
- os.environ['LOCAL_RANK'] = str(rank)
- os.environ['WORLD_SIZE'] = str(world_size)
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = str(port)
+ os.environ["RANK"] = str(rank)
+ os.environ["LOCAL_RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
make_and_consume_experience(strategy)
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini'])
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
@rerun_if_address_is_in_use()
def test_experience(world_size, strategy):
spawn(run_dist, world_size, strategy=strategy)
-if __name__ == '__main__':
- test_experience(2, 'colossalai')
+if __name__ == "__main__":
+ test_experience(2, "colossalai")
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
index b98b3615cd28..b2551ff5c0de 100644
--- a/applications/Chat/tests/test_models.py
+++ b/applications/Chat/tests/test_models.py
@@ -6,15 +6,16 @@
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.chatglm import ChatGLMActor
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.chatglm import ChatGLMActor
+from coati.models.llama import LlamaActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@@ -23,19 +24,24 @@
[
lambda: BLOOMActor(),
lambda: GPTActor(),
- # HACK: skip llama due to long execution time
- # lambda: LlamaActor(),
- lambda: OPTActor(),
- # lambda: ChatGLMActor(),
-])
-
-@pytest.mark.parametrize("generate_kwargs", [{
- "max_length": 64,
- "use_cache": True,
- "do_sample": True,
- "temperature": 1.0,
- "top_k": 50,
-}])
+ # HACK: skip llama due to long execution time
+ # lambda: LlamaActor(),
+ lambda: OPTActor(),
+ # lambda: ChatGLMActor(),
+ ],
+)
+@pytest.mark.parametrize(
+ "generate_kwargs",
+ [
+ {
+ "max_length": 64,
+ "use_cache": True,
+ "do_sample": True,
+ "temperature": 1.0,
+ "top_k": 50,
+ }
+ ],
+)
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
@@ -56,7 +62,7 @@ def test_utils():
"kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)),
- "action_mask": torch.randint(0, 2, (batch_size, num_labels))
+ "action_mask": torch.randint(0, 2, (batch_size, num_labels)),
}
fn_output = compute_reward(**fn_input)
assert fn_output.shape == (batch_size,)
@@ -66,9 +72,7 @@ def test_utils():
num_labels = 10
num_actions = 2
fn_input = {
- "output": {
- "logits": torch.randn((batch_size, seq_len, num_labels))
- },
+ "output": {"logits": torch.randn((batch_size, seq_len, num_labels))},
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
"num_actions": num_actions,
}
@@ -105,8 +109,9 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
- assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
- lora_model[i].lora_B @ lora_model[i].lora_A)
+ assert not torch.allclose(
+ old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A
+ )
@pytest.mark.parametrize("batch_size", [8])
@@ -116,54 +121,60 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
[
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()),
- # HACK: skip llama due to long execution time
- # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
- lambda: (OPTActor(), OPTCritic(), OPTRM()),
- lambda: (ChatGLMActor(), None, None),
-])
+ # HACK: skip llama due to long execution time
+ # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
+ lambda: (OPTActor(), OPTCritic(), OPTRM()),
+ lambda: (ChatGLMActor(), None, None),
+ ],
+)
@torch.no_grad()
-def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
- batch_size: int,
- seq_len: int):
+def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
critic_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
"action_mask": torch.randint(0, 2, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
rm_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
- "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
actor, critic, rm = models_maker()
if isinstance(actor, ChatGLMActor):
actor = actor.float()
- tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True)
+ tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
- actor_input ={
- "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1),
- "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
- }
+ actor_input = {
+ "input_ids": torch.cat(
+ (
+ torch.randint(0, 100, (batch_size, seq_len // 2)),
+ chatglm_special_token,
+ torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
+ ),
+ dim=1,
+ ),
+ "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
+ }
assert isinstance(actor, Actor)
- base_actor_model = get_base_model(actor)
+ get_base_model(actor)
actor_output = actor(**actor_input)
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
if critic:
assert isinstance(critic, Critic)
- base_critic_model = get_base_model(critic)
+ get_base_model(critic)
critic_output = critic(**critic_input)
- assert critic_output.shape == (batch_size, )
-
+ assert critic_output.shape == (batch_size,)
+
if rm:
assert isinstance(rm, RewardModel)
- base_rm_model = get_base_model(rm)
+ get_base_model(rm)
rm_output = rm(**rm_input)
- assert rm_output.shape == (batch_size, )
+ assert rm_output.shape == (batch_size,)
@pytest.mark.parametrize("batch_size", [16])
@@ -173,39 +184,59 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss()
loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels),
- "labels": torch.randint(0, num_labels, (batch_size, seq_len))
+ "labels": torch.randint(0, num_labels, (batch_size, seq_len)),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = PolicyLoss()
loss_input = {
- "log_probs": torch.randn(batch_size,),
- "old_log_probs": torch.randn(batch_size,),
- "advantages": torch.randn(batch_size,)
+ "log_probs": torch.randn(
+ batch_size,
+ ),
+ "old_log_probs": torch.randn(
+ batch_size,
+ ),
+ "advantages": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = ValueLoss()
loss_input = {
- "values": torch.randn(batch_size,),
- "old_values": torch.randn(batch_size,),
- "reward": torch.randn(batch_size,)
+ "values": torch.randn(
+ batch_size,
+ ),
+ "old_values": torch.randn(
+ batch_size,
+ ),
+ "reward": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = LogSigLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size,),
- "reject_reward": torch.randn(batch_size,),
+ "chosen_reward": torch.randn(
+ batch_size,
+ ),
+ "reject_reward": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
loss = LogExpLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size,),
- "reject_reward": torch.randn(batch_size,),
+ "chosen_reward": torch.randn(
+ batch_size,
+ ),
+ "reject_reward": torch.randn(
+ batch_size,
+ ),
}
- loss_output = loss(**loss_input)
+ loss(**loss_input)
if __name__ == "__main__":
@@ -218,4 +249,4 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int):
test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
- test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
+ test_loss(batch_size=8, seq_len=128, num_labels=100)
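The LoRA assertion above compares the low-rank products lora_B @ lora_A before and after conversion. For reference, the forward pass of a LoRA linear adds exactly that product as a delta on top of the frozen weight; a minimal sketch (shapes and the scaling factor are illustrative):

import torch
import torch.nn.functional as F

def lora_forward(x, weight, bias, lora_A, lora_B, scaling=1.0):
    # y = x W^T + b + scaling * (x A^T) B^T, with A: (r, in) and B: (out, r)
    return F.linear(x, weight, bias) + scaling * F.linear(F.linear(x, lora_A), lora_B)

x = torch.randn(4, 16)
weight, bias = torch.randn(8, 16), torch.randn(8)
lora_A, lora_B = torch.randn(2, 16), torch.zeros(8, 2)  # B at zero => delta is zero
assert torch.allclose(lora_forward(x, weight, bias, lora_A, lora_B), F.linear(x, weight, bias))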
diff --git a/colossalai/__init__.py b/colossalai/__init__.py
index fa6f72a605c0..7da55590305b 100644
--- a/colossalai/__init__.py
+++ b/colossalai/__init__.py
@@ -6,7 +6,7 @@
except ModuleNotFoundError:
# this will only happen if the user did not run `pip install`
# and directly set PYTHONPATH to use Colossal-AI, which is a bad practice
- __version__ = '0.0.0'
- print('please install Colossal-AI from https://www.colossalai.org/download or from source')
+ __version__ = "0.0.0"
+ print("please install Colossal-AI from https://www.colossalai.org/download or from source")
-__all__ = ['launch', 'launch_from_openmpi', 'launch_from_slurm', 'launch_from_torch', '__version__']
+__all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"]
diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py
index 4049be79c70f..e8ba88b0406d 100644
--- a/colossalai/_analyzer/_subclasses/_meta_registration.py
+++ b/colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
from packaging import version
@@ -24,25 +24,23 @@
def new(*args, **kwargs):
- return orig_empty(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty(*args, **kwargs, device=torch.device("meta"))
def new_strided(*args, **kwargs):
- return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))
def new_like(*args, **kwargs):
- return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty_like(*args, **kwargs, device=torch.device("meta"))
def register_meta(op, register_dispatcher=True):
-
def wrapper(f):
-
def add_func(op):
meta_table[op] = f
if register_dispatcher:
- name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+ name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -54,7 +52,7 @@ def add_func(op):
return wrapper
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
@@ -69,7 +67,6 @@ def meta_conv(
output_padding: List[int],
groups: int,
):
-
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Compute the length of one output dimension of the convolution
@@ -146,7 +143,8 @@ def calc_conv_nd_return_shape(
kernel_size[i],
stride[i],
output_padding_list[i],
- ))
+ )
+ )
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
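_formula implements the standard convolution output-length rule, floor((ln + 2p - d(k - 1) - 1) / s) + 1; spelled out as a standalone function with two spot checks:

def conv_output_length(ln: int, p: int, d: int, k: int, s: int) -> int:
    # floor((input + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1

assert conv_output_length(32, 1, 1, 3, 1) == 32  # "same" padding for a 3x3 kernel
assert conv_output_length(32, 0, 1, 3, 2) == 15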
@@ -180,19 +178,39 @@ def pick_memory_format():
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
- out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
+ out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
- def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
- padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
- *extra_args):
+ def meta__conv(
+ input_tensor: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ is_transposed: bool,
+ output_padding: List[int],
+ groups: int,
+ *extra_args,
+ ):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
- def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
- padding, dilation, transposed, output_padding, groups, output_mask):
+ def meta_conv_backward(
+ grad_output: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias_sizes,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ output_mask,
+ ):
return new_like(input), new_like(weight), new((bias_sizes))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -224,7 +242,6 @@ def meta_cuda_rnn(
batch_sizes,
dropout_state,
):
-
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -240,8 +257,11 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
- out_shape = ([mini_batch, seq_length, out_size *
- num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+ out_shape = (
+ [mini_batch, seq_length, out_size * num_directions]
+ if batch_first
+ else [seq_length, mini_batch, out_size * num_directions]
+ )
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -257,15 +277,21 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
- def meta_cudnn_rnn_backward(input: torch.Tensor,
- weight: torch.Tensor,
- weight_stride0: int,
- hx: torch.Tensor,
- cx: Optional[torch.Tensor] = None,
- *args,
- **kwargs):
- return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
- ()) # (grad_input, grad_weight, grad_hx, grad_cx)
+ def meta_cudnn_rnn_backward(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_stride0: int,
+ hx: torch.Tensor,
+ cx: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ):
+ return (
+ new_like(input),
+ new_like(weight),
+ new_like(hx),
+ new_like(cx) if cx is not None else new(()),
+ ) # (grad_input, grad_weight, grad_hx, grad_cx)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@@ -278,7 +304,7 @@ def meta_cudnn_rnn_backward(input: torch.Tensor,
aten.hardtanh_backward.default,
]
- if version.parse(torch.__version__) < version.parse('2.0.0'):
+ if version.parse(torch.__version__) < version.parse("2.0.0"):
_unregistered_ewise += [
aten.prelu_backward.default,
]
@@ -296,37 +322,61 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
- def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, train, eps, output_mask):
- return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+ def meta_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ train,
+ eps,
+ output_mask,
+ ):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
- return new_like(input), new((n_input)), new((n_input)), new(
- (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
+ return (
+ new_like(input),
+ new((n_input)),
+ new((n_input)),
+ new((0), dtype=torch.uint8),
+ ) # (output, running_mean, running_var, reserve)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
- def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, eps, reserve):
- return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+ def meta_cudnn_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ eps,
+ reserve,
+ ):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs, n_input = input.size(0), input.size(1)
- return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
+        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))  # (output, mean, rstd)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
- def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
- grad_input_mask):
- return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
+ def meta_ln_backward(
+ dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+ ):
+ return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
# ================================== Misc ==========================================
# Maybe incorrect
@@ -355,8 +405,9 @@ def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Te
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
- def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
- scale_grad_by_freq):
+ def meta_embedding_dense_backward(
+ grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+ ):
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
# ============================== Dropout ===========================================
@@ -364,14 +415,14 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
- return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
+ return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
- return new_like(grad) # (grad_in)
+ return new_like(grad) # (grad_in)
- if version.parse(torch.__version__) < version.parse('1.13.0'):
+ if version.parse(torch.__version__) < version.parse("1.13.0"):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
@@ -385,24 +436,28 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
- assert index.dtype in [torch.long, torch.int8, torch.bool],\
- "tensors used as indices must be long, byte or bool tensors"
+ assert index.dtype in [
+ torch.long,
+ torch.int8,
+ torch.bool,
+ ], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
- assert index.shape[j] == self.shape[
- k +
- j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+ assert (
+ index.shape[j] == self.shape[k + j]
+ ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
- assert len(
- indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+ assert (
+ len(indices) <= self.ndim
+ ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs
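For readers following the `meta_conv` hunks above: `_formula` computes the standard convolution output length, (ln + 2*p - d*(k - 1) - 1) // s + 1, which is what lets the meta implementations return correctly shaped `new_empty` tensors without running a real convolution. A minimal sketch with illustrative numbers (not taken from the patch):

    def conv_out_len(ln: int, p: int, d: int, k: int, s: int) -> int:
        # ln: input length, p: padding, d: dilation, k: kernel size, s: stride
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    # length-32 input, padding 1, dilation 1, kernel 3, stride 2:
    # (32 + 2 - 2 - 1) // 2 + 1 == 15 + 1 == 16
    assert conv_out_len(32, p=1, d=1, k=3, s=2) == 16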
diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py
index b3ec98f0811f..503981409cca 100644
--- a/colossalai/_analyzer/_subclasses/_monkey_patch.py
+++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py
@@ -1,5 +1,4 @@
import torch
-import torch.distributed as dist
from packaging import version
__all__ = [
@@ -48,7 +47,7 @@
"scatter",
]
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
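The `version.parse` gates used throughout these files come from `packaging`; a quick sketch of the comparison semantics (all values illustrative):

    from packaging import version

    assert version.parse("1.13.1") >= version.parse("1.12.0")
    assert version.parse("1.9.0") < version.parse("1.12.0")
    # PEP 440 pre-releases sort before the final release:
    assert version.parse("2.0.0a0") < version.parse("2.0.0")

Plain string comparison would get the last two wrong ("1.9.0" sorts after "1.12.0" lexicographically), which is why the code parses versions instead of comparing `torch.__version__` directly.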
diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py
index 59991dc50912..9d52c5593bb8 100644
--- a/colossalai/_analyzer/_subclasses/flop_tensor.py
+++ b/colossalai/_analyzer/_subclasses/flop_tensor.py
@@ -8,7 +8,7 @@
from enum import Enum, auto
from functools import partial, reduce
from numbers import Number
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Union
import torch
from packaging import version
@@ -36,15 +36,15 @@ def _format_flops(flop):
B = 1e9
T = 1e12
if flop < K:
- return f'{flop:.2f}'
+ return f"{flop:.2f}"
elif flop < M:
- return f'{flop / K:.2f}K'
+ return f"{flop / K:.2f}K"
elif flop < B:
- return f'{flop / M:.2f}M'
+ return f"{flop / M:.2f}M"
elif flop < T:
- return f'{flop / B:.2f}B'
+ return f"{flop / B:.2f}B"
else:
- return f'{flop / T:.2f}T'
+ return f"{flop / T:.2f}T"
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
Returns:
Number: The total number of floating point operations (FWD + BWD).
"""
- maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
- or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
+ maybe_inplace = (
+ getattr(module, "inplace", False)
+ or kwargs.get("inplace", False)
+ or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
+ )
class DummyModule(torch.nn.Module):
-
def __init__(self, func):
super().__init__()
self.func = func
@@ -74,21 +76,20 @@ def forward(self, *args, **kwargs):
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
flop_counts = defaultdict(lambda: defaultdict(int))
- parents = ['Global']
+ parents = ["Global"]
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
class FlopTensor(MetaTensor):
_tensor: torch.Tensor
def __repr__(self):
- name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
+ name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
# no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion.
rs = super().__torch_dispatch__(func, types, args, kwargs)
@@ -115,9 +116,7 @@ def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def create_backwards_push(name):
-
class PushState(torch.autograd.Function):
-
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -134,9 +133,7 @@ def backward(ctx, *grad_outs):
return PushState.apply
def create_backwards_pop(name):
-
class PopState(torch.autograd.Function):
-
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -147,14 +144,13 @@ def forward(ctx, *args):
@staticmethod
def backward(ctx, *grad_outs):
nonlocal parents
- assert (parents[-1] == name)
+ assert parents[-1] == name
parents.pop()
return grad_outs
return PopState.apply
def enter_module(name):
-
def f(module, inputs):
nonlocal parents
parents.append(name)
@@ -165,10 +161,9 @@ def f(module, inputs):
return f
def exit_module(name):
-
def f(module, inputs, outputs):
nonlocal parents
- assert (parents[-1] == name)
+ assert parents[-1] == name
parents.pop()
outputs = normalize_tuple(outputs)
return create_backwards_push(name)(*outputs)
@@ -189,7 +184,7 @@ def display_flops():
for mod in flop_counts.keys():
print(f"Module: ", mod)
for k, v in flop_counts[mod].items():
- print('\t', k, _format_flops(v))
+ print("\t", k, _format_flops(v))
print()
def detach_variables(r):
@@ -201,7 +196,7 @@ def detach_variables(r):
def wrap(r):
if isinstance(r, torch.Tensor):
- data_ptr_fn = getattr(r, '_tensor', r).data_ptr
+ data_ptr_fn = getattr(r, "_tensor", r).data_ptr
r = FlopTensor(detach_variables(r))
if maybe_inplace:
r = r + 0
@@ -375,8 +370,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
- has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
- 'shape') else inputs[affine_arg_index]
+ has_affine = (
+ inputs[affine_arg_index].shape is not None
+ if hasattr(inputs[affine_arg_index], "shape")
+ else inputs[affine_arg_index]
+ )
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
- return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
+ return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -420,33 +418,30 @@ def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
def zero_flop_jit(*args):
"""
- Count flops for zero flop layers.
+ Count flops for zero flop layers.
"""
return 0
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
flop_mapping = {
- # gemm
+ # gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
-
- # convolution
+ # convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
-
- # normalization
+ # normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
-
- # pooling
+ # pooling
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
@@ -469,7 +464,7 @@ def zero_flop_jit(*args):
}
ewise_flop_aten = [
- # basic op
+ # basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -485,8 +480,7 @@ def zero_flop_jit(*args):
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
-
- # activation op
+ # activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -509,15 +503,12 @@ def zero_flop_jit(*args):
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
-
- # dropout
+ # dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
-
- # distribution
+ # distribution
aten.bernoulli_.float,
-
- # where
+ # where
aten.where.self,
]
for op in ewise_flop_aten:
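A hedged usage sketch for `flop_count` as declared above; only the signature comes from this hunk, while the model, input shape, and `requires_grad` choice are illustrative:

    import torch
    import torch.nn as nn
    from colossalai._analyzer._subclasses.flop_tensor import flop_count

    model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
    x = torch.randn(4, 128, requires_grad=True)  # grad enabled so BWD FLOPs are counted
    total = flop_count(model, x, verbose=True)   # FWD + BWD, per the docstring
    print(total)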
diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py
index 2bc212938ee0..8be97d01343e 100644
--- a/colossalai/_analyzer/_subclasses/meta_tensor.py
+++ b/colossalai/_analyzer/_subclasses/meta_tensor.py
@@ -3,12 +3,12 @@
import torch
import torch.distributed as dist
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
-__all__ = ['MetaTensor', 'MetaTensorMode']
+__all__ = ["MetaTensor", "MetaTensorMode"]
def register_storage(r, data_ptr_fn=None):
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
# a hack for inplace execution in PyTorch
def _assert_alias(func):
- return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
- )
+    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen)  # TODO: check if it should be this aggressive
class MetaTensor(torch.Tensor):
@@ -65,14 +64,15 @@ def __new__(cls, elem, device=None, data_ptr_fn=None):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
- device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
- requires_grad=requires_grad) # deceive the frontend for aten selections
+ device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+ requires_grad=requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
val = elem.data_ptr()
data_ptr_fn = lambda: val
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
register_storage(r._tensor, data_ptr_fn)
@@ -81,7 +81,7 @@ def __new__(cls, elem, device=None, data_ptr_fn=None):
return r
def __repr__(self):
- name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
+ name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@@ -97,15 +97,15 @@ def unwrap(x):
x = x._tensor
elif isinstance(x, torch.Tensor):
device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
- if 'device' in kwargs:
- device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
# here we detect whether or not the execution generates a physical copy
@@ -143,21 +143,21 @@ def replace(x):
nonlocal device
if isinstance(x, str) or isinstance(x, _device):
device = x
- return torch.device('meta')
+ return torch.device("meta")
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, device=device)
def cpu(self, *args, **kwargs):
- if self.device.type == 'cpu':
+ if self.device.type == "cpu":
return self.to(*args, **kwargs)
- return self.to(*args, device='cpu', **kwargs)
+ return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
- return self.to(device='cuda:0', non_blocking=non_blocking)
+ return self.to(device="cuda:0", non_blocking=non_blocking)
def data_ptr(self):
return self._tensor.data_ptr()
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
"""
def __init__(self):
- self.torch_overrides = {} # override torch.xxx
- self.dist_overrides = {} # override torch.distributed.xxx
+ self.torch_overrides = {} # override torch.xxx
+ self.dist_overrides = {} # override torch.distributed.xxx
def __enter__(self):
-
def _dummy(*args, **kwargs):
pass
def _new(*args, orig_new=torch.empty, **kwargs):
- return MetaTensor(orig_new(*args, **{
- **kwargs, 'device': 'meta'
- }),
- device=kwargs.get('device', torch.device('cpu')))
+ return MetaTensor(
+ orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
+ )
for func in _TorchOverrideableFactoryMethod:
self.torch_overrides[func] = getattr(torch, func)
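Based on the `_new` override above, a sketch of the intended `MetaTensorMode` behaviour (shapes and device string are made up):

    import torch
    from colossalai._analyzer._subclasses.meta_tensor import MetaTensorMode

    with MetaTensorMode():
        # factory calls are rerouted to device="meta": no real memory is
        # allocated, yet the tensor still reports the requested device
        x = torch.empty(1024, 1024, device="cuda:0")

    print(x.shape)   # torch.Size([1024, 1024])
    print(x.device)  # cuda:0, while the storage actually lives on the meta device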
diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py
index 41d74f2e3719..cd244b22cac0 100644
--- a/colossalai/_analyzer/fx/codegen.py
+++ b/colossalai/_analyzer/fx/codegen.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, List, Tuple
import torch
@@ -22,7 +22,7 @@
import colossalai
from colossalai.fx._compatibility import compatibility
-_register_custom_builtin('colossalai', 'import colossalai', colossalai)
+_register_custom_builtin("colossalai", "import colossalai", colossalai)
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
"""
Generate the checkpoint function call code text
"""
- outputs = ', '.join(output_vars)
- inputs = ', '.join(input_vars)
- return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
+ outputs = ", ".join(output_vars)
+ inputs = ", ".join(input_vars)
+ return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
- if len(node.meta['info'].activation_checkpoint) > ckpt_level:
- return node.meta['info'].activation_checkpoint[ckpt_level] is not None
+ if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+ return node.meta["info"].activation_checkpoint[ckpt_level] is not None
return True
@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None
for idx, node in enumerate(node_list):
- if len(node.meta['info'].activation_checkpoint) > ckpt_level:
- act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
+ if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+ act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
return ckpt_regions
-def emit_ckpt_func(body,
- ckpt_func,
- node_list: List[Node],
- emit_node_func,
- delete_unused_value_func,
- ckpt_level=0,
- in_ckpt=False):
+def emit_ckpt_func(
+ body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
+):
"""Emit ckpt function in nested way
Args:
@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1'
- label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
+ label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch
- if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
+ if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
- emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
- ckpt_level + 1, True)
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(
+ ckpt_func,
+ ckpt_func_buffer,
+ ckpt_node_list,
+ emit_node_func,
+ delete_unused_value_func,
+ ckpt_level + 1,
+ True,
+ )
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
- usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
+ usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# process ckpt_regions
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
@compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen):
-
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
@@ -251,7 +253,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -259,7 +261,7 @@ def add_global(name_hint: str, obj: Any):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -281,16 +283,16 @@ def add_global(name_hint: str, obj: Any):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
- if hasattr(o, '__args__'):
+ if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
@@ -309,19 +311,18 @@ def type_repr(o: Any):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
- if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
- args_s = ', '.join(_get_repr(a) for a in args)
- kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+ args_s = ", ".join(_get_repr(a) for a in args)
+ kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
- return f'{args_s}, {kwargs_s}'
+ return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
@@ -347,82 +348,94 @@ def delete_unused_values(user: Node, body):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
- body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
- f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
@@ -432,13 +445,13 @@ def emit_node(node: Node, body):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
@@ -447,11 +460,11 @@ def emit_node(node: Node, body):
add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
- prologue = ''.join(ckpt_func) + prologue
+ prologue = "".join(ckpt_func) + prologue
prologue = prologue
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
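To make the emitted code concrete: for a hypothetical checkpoint region labelled `0_1` with inputs `x, mask` and a single output `out`, `_gen_ckpt_usage` (called with `use_reentrant=False`, as in `emit_ckpt_func`) places this line in the generated forward:

    out = torch.utils.checkpoint.checkpoint(self.checkpoint_0_1, x, mask, use_reentrant=False)

The matching `checkpoint_0_1` definition is what `emit_ckpt_func` accumulates in `ckpt_func` and what `_gen_python_code` prepends to the prologue.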
diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py
index 1fdedd758c01..9d3999e322b9 100644
--- a/colossalai/_analyzer/fx/graph_module.py
+++ b/colossalai/_analyzer/fx/graph_module.py
@@ -13,6 +13,7 @@
try:
from torch.fx.graph import _PyTreeCodeGen
+
SUPPORT_PT_CODEGEN = True
except ImportError:
SUPPORT_PT_CODEGEN = False
@@ -24,7 +25,6 @@
# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:
-
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
@@ -50,12 +50,14 @@ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
# constituent substrings of the error message
tb_repr = traceback.format_exc()
- custom_msg = ("Call using an FX-traced Module, "
- f"line {err_lineno} of the traced Module's "
- "generated forward function:")
- before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+ custom_msg = (
+ "Call using an FX-traced Module, "
+ f"line {err_lineno} of the traced Module's "
+ "generated forward function:"
+ )
+ before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
- err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+ err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
@@ -65,11 +67,14 @@ def __call__(self, obj, *args, **kwargs):
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
- return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
+ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
- topmost_framesummary: traceback.FrameSummary = \
- traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
+ topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
+ traceback.walk_tb(e.__traceback__)
+ )[
+ -1
+ ] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None)
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
code.
"""
- def __init__(self,
- root: Union[torch.nn.Module, Dict[str, Any]],
- graph: torch.fx.Graph,
- class_name: str = 'GraphModule'):
+ def __init__(
+ self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
+ ):
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
@@ -134,7 +138,7 @@ def recompile(self) -> PythonCode:
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
- python_code = self._graph.python_code(root_module='self')
+ python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
@@ -157,8 +161,8 @@ def recompile(self) -> PythonCode:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
- if '_wrapped_call' not in vars(cls):
- cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
+ if "_wrapped_call" not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -182,7 +186,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
- torch.save(self.state_dict(), folder / 'state_dict.pt')
+ torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -208,10 +212,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
- module_file = folder / f'{module_name}.pt'
+ module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
- module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+ module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -228,12 +232,14 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
- module_file = folder / 'module.py'
+ module_file = folder / "module.py"
module_file.write_text(model_str)
- init_file = folder / '__init__.py'
- init_file.write_text('from .module import *')
+ init_file = folder / "__init__.py"
+ init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
- warnings.warn("Was not able to save the following children modules as reprs -"
- f"saved as pickled files instead: {blobified_modules}")
+ warnings.warn(
+ "Was not able to save the following children modules as reprs -"
+ f"saved as pickled files instead: {blobified_modules}"
+ )
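A hypothetical end-to-end use of `ColoGraphModule` (the traced module and folder name are placeholders; the constructor signature and the files written by `to_folder` follow the hunks above):

    import torch
    from torch.fx import symbolic_trace
    from colossalai._analyzer.fx.graph_module import ColoGraphModule

    traced = symbolic_trace(torch.nn.Linear(4, 4))
    gm = ColoGraphModule(traced, traced.graph)
    # writes state_dict.pt, module.py (including `import colossalai`), __init__.py
    gm.to_folder("exported_module", module_name="FxModule")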
diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py
index fbe8400a437e..d2671787ea63 100644
--- a/colossalai/_analyzer/fx/node_util.py
+++ b/colossalai/_analyzer/fx/node_util.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torch
-from torch.autograd.profiler_util import _format_memory, _format_time
-from torch.fx import Graph, GraphModule, Node
+from torch.autograd.profiler_util import _format_memory
+from torch.fx import Node
from colossalai._analyzer.envs import MeshConfig
@@ -85,12 +85,12 @@ class MetaInfo:
node: Node
# directory
- mod_dir: str = ''
+ mod_dir: str = ""
# ctx[data_ptr] = Tensor
# mark the storage for ctx.save_for_backward
- global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
- curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
+ global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
+ curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
# should be updated after each graph manipulation
# ============================== Update ====================================
@@ -100,7 +100,7 @@ class MetaInfo:
inputs: Tuple[torch.Tensor] = ()
outputs: Tuple[torch.Tensor] = ()
- is_alias: Tuple[bool] = () # whether the output is an alias of input
+ is_alias: Tuple[bool] = () # whether the output is an alias of input
# compute cost
fwd_flop: Optional[int] = 0
@@ -112,29 +112,29 @@ class MetaInfo:
# should keep the same whenever manipulated
# ============================= Invariant ==================================
- activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
+ activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
to_offload: Optional[bool] = False
- sharding_spec: str = 'RR'
+ sharding_spec: str = "RR"
def __new__(cls, node: Node, **kwargs):
orig_init = cls.__init__
# if initialized, return the existing one
# should disable the __init__ function
- if node.meta.get('info', None) is not None:
+ if node.meta.get("info", None) is not None:
def _dummy(self, *args, **kwargs):
- if getattr(self, '_is_init', False):
+ if getattr(self, "_is_init", False):
self._is_init = True
orig_init(self, *args, **kwargs)
cls.__init__ = orig_init
cls.__init__ = _dummy
- return node.meta['info']
+ return node.meta["info"]
return super().__new__(cls)
def __post_init__(self):
- self.node.meta['info'] = self
+ self.node.meta["info"] = self
@property
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
@@ -188,24 +188,26 @@ def backward_size(self):
return compute_size_in_bytes(self.inputs)
def __repr__(self):
- s = f'Node {self.node.name}'
+ s = f"Node {self.node.name}"
if self.parameters:
- s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
+ s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
if self.buffers:
- s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
+ s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
if self.output_size:
- s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
+ s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
# if self.total_size:
# s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
if self.temp_size:
- s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
+ s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
if self.backward_size:
- s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
- s += f'\n\tfwd_flop = {self.fwd_flop}'\
- f'\n\tbwd_flop = {self.bwd_flop}'\
- f'\n\tfwd_comm = {self.fwd_comm}'\
- f'\n\tbwd_comm = {self.bwd_comm}'\
- f'\n\tto_recompute = {self.to_recompute}'\
- f'\n\tto_offload = {self.to_offload}'\
- f'\n\tsharding_spec = {self.sharding_spec}'
+ s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
+ s += (
+ f"\n\tfwd_flop = {self.fwd_flop}"
+ f"\n\tbwd_flop = {self.bwd_flop}"
+ f"\n\tfwd_comm = {self.fwd_comm}"
+ f"\n\tbwd_comm = {self.bwd_comm}"
+ f"\n\tto_recompute = {self.to_recompute}"
+ f"\n\tto_offload = {self.to_offload}"
+ f"\n\tsharding_spec = {self.sharding_spec}"
+ )
return s
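A sketch of the per-node caching that `__new__` and `__post_init__` above implement: constructing `MetaInfo` twice for the same node returns the same object (the traced module is illustrative):

    import torch
    from torch.fx import symbolic_trace
    from colossalai._analyzer.fx.node_util import MetaInfo

    gm = symbolic_trace(torch.nn.ReLU())
    node = next(iter(gm.graph.nodes))

    info_a = MetaInfo(node)  # __post_init__ stores it in node.meta["info"]
    info_b = MetaInfo(node)  # __new__ returns the cached instance
    assert info_a is info_b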
diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py
index c3e760b31e96..158ebce219cd 100644
--- a/colossalai/_analyzer/fx/passes/graph_profile.py
+++ b/colossalai/_analyzer/fx/passes/graph_profile.py
@@ -1,8 +1,8 @@
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
import torch.fx
-from torch.autograd.profiler_util import _format_memory, _format_time
+from torch.autograd.profiler_util import _format_memory
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
@@ -13,14 +13,14 @@
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
- return f'{flops / 1e12:.2f} TFLOPs'
+ return f"{flops / 1e12:.2f} TFLOPs"
elif flops > 1e9:
- return f'{flops / 1e9:.2f} GFLOPs'
+ return f"{flops / 1e9:.2f} GFLOPs"
elif flops > 1e6:
- return f'{flops / 1e6:.2f} MFLOPs'
+ return f"{flops / 1e6:.2f} MFLOPs"
elif flops > 1e3:
- return f'{flops / 1e3:.2f} kFLOPs'
- return f'{flops} FLOPs'
+ return f"{flops / 1e3:.2f} kFLOPs"
+ return f"{flops} FLOPs"
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""
+
_profileable = [
- 'call_function',
- 'call_module',
- 'call_method',
+ "call_function",
+ "call_module",
+ "call_method",
]
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
@@ -77,14 +78,13 @@ def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_pr
self.args_iter: Iterator[Any] = iter(args)
for node in self.module.graph.nodes:
-
- self.run_node(node) # No need to store.
+ self.run_node(node) # No need to store.
if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]
- if node.op == 'output':
+ if node.op == "output":
output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
@@ -133,9 +133,11 @@ def summary(self) -> str:
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
@@ -145,36 +147,38 @@ def summary(self) -> str:
node: Node
n_info = MetaInfo(node)
last_n_info = last_n_info or n_info
- node_summaries.append([
- node.op,
- str(node),
- _format_memory(n_info.accumulate_size),
- _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
- _format_memory(n_info.output_size),
- _format_memory(n_info.temp_size),
- _format_memory(n_info.param_size),
- _format_memory(n_info.backward_size),
- _format_flops(n_info.fwd_flop),
- _format_flops(n_info.bwd_flop),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ _format_memory(n_info.accumulate_size),
+ _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
+ _format_memory(n_info.output_size),
+ _format_memory(n_info.temp_size),
+ _format_memory(n_info.param_size),
+ _format_memory(n_info.backward_size),
+ _format_flops(n_info.fwd_flop),
+ _format_flops(n_info.bwd_flop),
+ ]
+ )
last_n_info = n_info
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Accumulate size',
- 'Incremental size',
- 'Output size',
- 'Temp size',
- 'Param size',
- 'Backward size',
- 'Fwd FLOPs',
- 'Bwd FLOPs',
+ "Op type",
+ "Op",
+ "Accumulate size",
+ "Incremental size",
+ "Output size",
+ "Temp size",
+ "Param size",
+ "Backward size",
+ "Fwd FLOPs",
+ "Bwd FLOPs",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
class CommunicationProfiler(GraphProfiler):
@@ -222,6 +226,7 @@ class with the ``@register_flop_count_impl`` decorator:
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""
+
_custom_flop_count_impl = {}
def run_node(self, n: torch.fx.Node) -> Any:
@@ -246,11 +251,13 @@ def run_node(self, n: torch.fx.Node) -> Any:
(
n_info.fwd_flop,
n_info.bwd_flop,
- ) = getattr(self, n.op)(n.target, args, kwargs)
+ ) = getattr(
+ self, n.op
+ )(n.target, args, kwargs)
except Exception as e:
raise RuntimeError(
- f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
- f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
+ f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
+ f"Please refer to function's docstring to register the relevant profile_impl for this node!"
) from e
# retain the autograd graph
@@ -259,7 +266,7 @@ def run_node(self, n: torch.fx.Node) -> Any:
return _denormalize_tuple(n_info.outputs)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
@@ -283,7 +290,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di
else:
return flop_count(target, *args, **kwargs)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the profiling result.
@@ -301,7 +308,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the profiling result.
@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
Returns:
GraphModule: The same GraphModule with profiling information
"""
- for profiler_cls in (FlopProfiler,
- # CommunicationProfiler, # TODO: add communication profiling
- ):
+ for profiler_cls in (
+ FlopProfiler,
+ # CommunicationProfiler, # TODO: add communication profiling
+ ):
profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module))
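A hedged sketch of running the profiling pass above on a toy model; `shape_prop_pass` is assumed to take `(module, *sample_args)` like `graph_profile_pass`, so shapes are recorded before FLOPs are filled in:

    import torch
    from torch.fx import symbolic_trace
    from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
    from colossalai._analyzer.fx.passes.graph_profile import graph_profile_pass

    gm = symbolic_trace(torch.nn.Linear(16, 16))
    x = torch.randn(2, 16)
    shape_prop_pass(gm, x)     # record shapes into each node's MetaInfo
    graph_profile_pass(gm, x)  # fill in fwd/bwd FLOPs per node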
diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py
index 23e83013e02f..8d44f1d4b59d 100644
--- a/colossalai/_analyzer/fx/passes/shape_prop.py
+++ b/colossalai/_analyzer/fx/passes/shape_prop.py
@@ -54,7 +54,7 @@ def _current_device(module):
try:
return next(module.parameters()).device
except StopIteration:
- return torch.device('cpu')
+ return torch.device("cpu")
@compatibility(is_backward_compatible=False)
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
>>> # do something here
>>> return torch.empty(output_shape, device=output_device)
"""
+
_custom_dispatch_func = {}
_mode = MetaTensorMode()
@@ -115,15 +116,14 @@ def run_node(self, n: torch.fx.Node) -> Any:
r = getattr(self, n.op)(n.target, args, kwargs)
def unwrap_fn(elem):
-
def _convert_meta(t: torch.Tensor):
- if t.device == 'meta':
+ if t.device == "meta":
return t
else:
- return t.to('meta')
+ return t.to("meta")
if isinstance(elem, MetaTensor):
- if getattr(self, '_is_param', False):
+ if getattr(self, "_is_param", False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)
@@ -139,21 +139,24 @@ def _convert_meta(t: torch.Tensor):
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
- if n.op == 'call_module':
+ if n.op == "call_module":
submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
else:
- n_info.parameters.update({
- k.name: MetaTensor(v)
- for k, v in zip(n.args, args)
- if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
- })
+ n_info.parameters.update(
+ {
+ k.name: MetaTensor(v)
+ for k, v in zip(n.args, args)
+ if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
+ }
+ )
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
- n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
- tuple(v for v in kwargs.values() if is_pure_tensor(v))
+ n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
+ v for v in kwargs.values() if is_pure_tensor(v)
+ )
# align with SPMD
if isinstance(r, (tuple, list)):
@@ -168,7 +171,7 @@ def _convert_meta(t: torch.Tensor):
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r
- def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``,
@@ -197,7 +200,7 @@ def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[st
else:
return res
- def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
@@ -218,7 +221,8 @@ def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str,
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
- args[0], torch.nn.parameter.Parameter):
+ args[0], torch.nn.parameter.Parameter
+ ):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
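
Note: the shape-propagation pass touched above executes every FX node with meta tensors so output shapes and devices are recorded without allocating real activations. As a minimal sketch of the underlying idea, using only stock `torch.fx` rather than the `MetaTensor`/`ColoProxy` machinery in this patch:

```python
import torch
import torch.fx


class SimpleShapeProp(torch.fx.Interpreter):
    """Record the output shape of every node while executing on 'meta'."""

    def run_node(self, n: torch.fx.Node):
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            n.meta["shape"] = tuple(result.shape)  # shape is known, data is not
        return result


model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model).to("meta")
SimpleShapeProp(gm).run(torch.empty(2, 4, device="meta"))
for node in gm.graph.nodes:
    print(node.name, node.meta.get("shape"))
```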
diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py
index dd7f22c6c98a..5732a6665f78 100644
--- a/colossalai/_analyzer/fx/symbolic_profile.py
+++ b/colossalai/_analyzer/fx/symbolic_profile.py
@@ -1,5 +1,3 @@
-import torch
-import torch.fx
from torch.fx import GraphModule
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
@@ -7,7 +5,6 @@
def register_flop_count_impl(func):
-
def wrapper(impl):
FlopProfiler._custom_flop_count_impl[func] = impl
return impl
@@ -16,7 +13,6 @@ def wrapper(impl):
def register_shape_impl(func):
-
def wrapper(impl):
ShapeProp._custom_dispatch_func[func] = impl
return impl
diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py
index 1e75b47ca5b0..b8b83282b42c 100644
--- a/colossalai/_analyzer/fx/tracer/bias_addition.py
+++ b/colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -12,7 +12,7 @@
__all__ = []
-@register_tracer_impl(F.linear, name='_bias_addition_impl')
+@register_tracer_impl(F.linear, name="_bias_addition_impl")
def linear_impl(input, weight, bias=None):
if bias is None:
return F.linear(input, weight)
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
return F.linear(input, weight) + bias
-@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1))
+ (-1, 1)
+ )
-@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1, 1))
+ (-1, 1, 1)
+ )
-@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1, 1, 1))
-
-
-@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input,
- weight,
- bias=None,
- stride=_single(1),
- padding=_single(0),
- output_padding=_single(0),
- groups=1,
- dilation=_single(1)):
+ (-1, 1, 1, 1)
+ )
+
+
+@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
+def conv_transpose1d_impl(
+ input,
+ weight,
+ bias=None,
+ stride=_single(1),
+ padding=_single(0),
+ output_padding=_single(0),
+ groups=1,
+ dilation=_single(1),
+):
if bias is None:
- return F.conv_transpose1d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose1d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose1d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1))
-
-
-@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input,
- weight,
- bias=None,
- stride=_pair(1),
- padding=_pair(0),
- output_padding=_pair(0),
- groups=1,
- dilation=_pair(1)):
+ return F.conv_transpose1d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1))
+
+
+@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
+def conv_transpose2d_impl(
+ input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
+):
if bias is None:
- return F.conv_transpose2d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose2d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose2d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1, 1))
-
-
-@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input,
- weight,
- bias=None,
- stride=_triple(1),
- padding=_triple(0),
- output_padding=_triple(0),
- groups=1,
- dilation=_triple(1)):
+ return F.conv_transpose2d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1, 1))
+
+
+@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
+def conv_transpose3d_impl(
+ input,
+ weight,
+ bias=None,
+ stride=_triple(1),
+ padding=_triple(0),
+ output_padding=_triple(0),
+ groups=1,
+ dilation=_triple(1),
+):
if bias is None:
- return F.conv_transpose3d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose3d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose3d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1, 1, 1))
-
-
-@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
+ return F.conv_transpose3d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1, 1, 1))
+
+
+@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
return F.linear(mat1, mat2.transpose(0, 1)) + input
-@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
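
Note: these `_bias_addition_impl` overrides exist so the tracer records the matmul/convolution and the bias addition as two separate graph nodes, which later sharding passes can handle independently. A minimal illustration of the transformation for `F.linear` (illustrative only, not the tracer API):

```python
import torch
import torch.nn.functional as F


def linear_split(input, weight, bias=None):
    # Numerically equivalent to F.linear(input, weight, bias), but the bias
    # addition is a separate op, so a tracer records two nodes instead of one.
    out = F.linear(input, weight)
    return out if bias is None else out + bias


x, w, b = torch.randn(2, 4), torch.randn(8, 4), torch.randn(8)
assert torch.allclose(linear_split(x, w, b), F.linear(x, w, b))
```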
diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
index 112c7c9637d2..ff6b55be5117 100644
--- a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
+++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
@@ -4,6 +4,7 @@
try:
import apex
+
register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm)
diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py
index ce379efdcf0d..e3e210e7d190 100644
--- a/colossalai/_analyzer/fx/tracer/proxy.py
+++ b/colossalai/_analyzer/fx/tracer/proxy.py
@@ -1,10 +1,8 @@
import operator
-from typing import Any, Callable, Dict, Optional, Set, Union
+from typing import Any, Callable, Dict, Optional, Union
import torch
-import torch.nn as nn
-from torch.fx import Graph, Node, Proxy, Tracer
-from torch.fx.graph import _Namespace
+from torch.fx import Node, Proxy
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
@@ -32,7 +30,7 @@ def meta_data(self, args):
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
- impl = cls._func_dispatch.pop(orig_method) # avoid recursion
+ impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
@@ -72,7 +70,7 @@ def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
- proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+ proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -89,7 +87,6 @@ def __isinstancecheck__(self, type):
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -102,11 +99,11 @@ def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
- self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+ self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+ return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
index 2018863f6f5f..7884fd911c86 100644
--- a/colossalai/_analyzer/fx/tracer/symbolic_trace.py
+++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Tracer
@@ -8,6 +8,7 @@
try:
from ..codegen import ActivationCheckpointCodeGen
+
SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@
def _default_device():
- return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def forward(self, x):
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
- graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
- bias_addition_split=bias_addition_split).trace(root.to(device),
- concrete_args=concrete_args,
- meta_args=tree_map(wrap_fn, meta_args))
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
+ root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+ )
if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
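
Note: the `meta_args` path reformatted above lets callers describe inputs by shape alone. A hypothetical usage sketch — the import path and the `"input"` placeholder name are assumptions, not taken from this patch:

```python
import torch
from colossalai._analyzer.fx.tracer.symbolic_trace import symbolic_trace  # assumed path

model = torch.nn.Linear(16, 32)
# 'input' matches nn.Linear.forward's argument name; no real data is allocated.
gm = symbolic_trace(model, meta_args={"input": torch.empty(8, 16, device="meta")})
print(gm.graph)
```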
diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py
index 6958a00a6a72..17dce767269d 100644
--- a/colossalai/_analyzer/fx/tracer/tracer.py
+++ b/colossalai/_analyzer/fx/tracer/tracer.py
@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
import re
    # FIXME: not sure why, but torch.fx always appends a suffix like '_1' to the name
- return re.sub(r'_\d+$', '', s)
+ return re.sub(r"_\d+$", "", s)
-def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
-
+def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
@@ -34,7 +33,6 @@ def wrapper(impl):
def register_leaf_module_impl(module: nn.Module):
-
def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
@@ -76,7 +74,7 @@ def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = Fal
self.ckpt_regions = []
self.ckpt_idx = 0
- self.mod_dir = ''
+ self.mod_dir = ""
# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
@@ -87,35 +85,41 @@ def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
- return (type(m) not in self._custom_non_leaf_module
- and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
+ return type(m) not in self._custom_non_leaf_module and (
+ type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
+ )
- def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
- kwargs: Dict[str, Any]) -> Any:
+ def call_module(
+ self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
+ ) -> Any:
curr_dir = self.mod_dir
- self.mod_dir = 'self.' + self.path_of_module(m)
+ self.mod_dir = "self." + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst
- def proxy(self, node: Node) -> 'ColoProxy':
+ def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
- def create_proxy(self,
- kind: str,
- target: Target,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- name: Optional[str] = None,
- type_expr: Optional[Any] = None,
- proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-
+ def create_proxy(
+ self,
+ kind: str,
+ target: Target,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None,
+ proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+ ):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
- if kind == 'placeholder':
- proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
- _truncate_suffix(target), None)
- elif kind == 'get_attr':
+ if kind == "placeholder":
+ proxy.meta_data = (
+ self.meta_args[target]
+ if target in self.meta_args
+ else self.concrete_args.get(_truncate_suffix(target), None)
+ )
+ elif kind == "get_attr":
self.disable_module_getattr = True
try:
attr_itr = self.root
@@ -125,20 +129,21 @@ def create_proxy(self,
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
- elif kind == 'call_function':
+ elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
+ elif kind == "call_method":
self.disable_module_getattr = True
try:
- if target == '__call__':
+ if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
+ proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
finally:
self.disable_module_getattr = False
- elif kind == 'call_module':
+ elif kind == "call_module":
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
@@ -158,11 +163,12 @@ def create_node(self, *args, **kwargs) -> Node:
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
- def trace(self,
- root: torch.nn.Module,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None,
- meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
-
+ def trace(
+ self,
+ root: torch.nn.Module,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+ meta_args: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Graph:
if meta_args is None:
meta_args = {}
@@ -177,9 +183,7 @@ def trace(self,
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
- if k in sig_names - meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]):
@@ -194,9 +198,9 @@ def _check_arg_name_valid(names: Iterable[str]):
self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
- self.mod_dir = 'self'
+ self.mod_dir = "self"
self.graph = super().trace(root, concrete_args=concrete_args)
- self.mod_dir = ''
+ self.mod_dir = ""
self.graph.lint()
for node in self.graph.nodes:
@@ -266,17 +270,17 @@ def _torch_factory_override(self):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):
-
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
- isinstance(p, ColoProxy) for p in kwargs.values())
+ isinstance(p, ColoProxy) for p in kwargs.values()
+ )
if is_proxy:
                 # if the arg is a proxy, we need to record that this function was called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
- proxy = self.create_proxy('call_function', target, args, kwargs)
+ proxy = self.create_proxy("call_function", target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
@@ -341,10 +345,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
- if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
- kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
- lambda node: ColoProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ColoProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -355,8 +362,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
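
Note: the `_torch_factory_override` context reformatted above wraps tensor factory functions so that a call like `torch.ones(size)` becomes a `call_function` node whenever `size` is a proxy. A generic sketch of that wrapping pattern, decoupled from `ColoTracer` (the proxy test and recorder below are stand-ins):

```python
import functools
import torch


def wrap_factory(target, is_proxy, record):
    # If any argument is a proxy, record a graph node instead of executing
    # eagerly; otherwise fall through to the real factory function.
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        if any(is_proxy(a) for a in args) or any(is_proxy(v) for v in kwargs.values()):
            return record(target, args, kwargs)
        return target(*args, **kwargs)

    return wrapper


calls = []
ones = wrap_factory(torch.ones, lambda a: isinstance(a, str),
                    lambda t, a, k: calls.append((t, a, k)))
ones(2, 3)           # no proxy arguments: executes eagerly
ones("proxy_size")   # stand-in proxy: recorded as a graph node instead
print(len(calls))    # 1
```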
diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py
index dc8499d877e1..34a20e8d67d6 100644
--- a/colossalai/amp/naive_amp/grad_scaler/__init__.py
+++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py
@@ -2,4 +2,4 @@
from .constant_grad_scaler import ConstantGradScaler
from .dynamic_grad_scaler import DynamicGradScaler
-__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
+__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"]
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index 0d84384a7f67..79661a44424f 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -9,7 +9,7 @@
from colossalai.logging import get_dist_logger
-__all__ = ['BaseGradScaler']
+__all__ = ["BaseGradScaler"]
class BaseGradScaler(ABC):
@@ -30,24 +30,21 @@ def __init__(self, initial_scale: float, verbose: bool):
@property
def scale(self) -> Tensor:
- """Returns the loss scale.
- """
+ """Returns the loss scale."""
return self._scale
@property
def inv_scale(self) -> Tensor:
- """Returns the inverse of the loss scale.
- """
+ """Returns the inverse of the loss scale."""
return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict:
- """Returns the states of the gradient scaler as a dict object.
- """
+ """Returns the states of the gradient scaler as a dict object."""
state_dict = dict()
- state_dict['scale'] = self.scale
+ state_dict["scale"] = self.scale
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
@@ -57,7 +54,7 @@ def load_state_dict(self, state_dict: Dict) -> None:
state_dict (dict): the states of the gradient scaler
"""
- self._scale = state_dict['scale']
+ self._scale = state_dict["scale"]
@abstractmethod
def update(self, overflow: bool) -> None:
@@ -67,8 +64,6 @@ def update(self, overflow: bool) -> None:
overflow (bool): whether overflow occurs
"""
- pass
-
def log(self, message, *args, **kwargs):
"""Log messages.
diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
index a2f518c5dd28..2ad8b51ac22c 100644
--- a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
from .base_grad_scaler import BaseGradScaler
-__all__ = ['ConstantGradScaler']
+__all__ = ["ConstantGradScaler"]
class ConstantGradScaler(BaseGradScaler):
@@ -23,4 +23,3 @@ def update(self, overflow: bool) -> None:
Args:
overflow (bool): whether overflow occurs
"""
- pass
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index e899b9ca4c89..65133a4b3712 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -7,7 +7,7 @@
from .base_grad_scaler import BaseGradScaler
-__all__ = ['DynamicGradScaler']
+__all__ = ["DynamicGradScaler"]
class DynamicGradScaler(BaseGradScaler):
@@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler):
verbose (bool): whether to log messages, defaults to False
"""
- def __init__(self,
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- min_scale: Optional[float] = None,
- max_scale: Optional[float] = None,
- hysteresis: int = 2,
- verbose: bool = False):
+ def __init__(
+ self,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ min_scale: Optional[float] = None,
+ max_scale: Optional[float] = None,
+ hysteresis: int = 2,
+ verbose: bool = False,
+ ):
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale])
@@ -53,18 +55,17 @@ def __init__(self,
self._sanity_checks()
def _sanity_checks(self) -> None:
- """Check if the arguments are correct.
- """
+ """Check if the arguments are correct."""
if self._min_scale:
- assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
- assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
+ assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
+ assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale"
if self._max_scale:
- assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
- assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
- assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
- assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
- assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
+ assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
+ assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale"
+ assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
+ assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1"
+ assert self._hysteresis >= 0, "The hysteresis cannot be negative"
def update(self, overflow: bool) -> None:
"""Update the loss scale.
@@ -88,19 +89,18 @@ def update(self, overflow: bool) -> None:
self.log(
f"No overflow for consecutive {self._growth_interval} steps, "
f"the loss scale is adjusted to {self.scale.item()}",
- ranks=[0])
+ ranks=[0],
+ )
def _backoff_scale(self) -> None:
- """Decrease the loss scale
- """
+ """Decrease the loss scale"""
self._scale = self._scale * self._backoff_factor
if self._min_scale:
self._scale = torch.max(self._scale, self._min_scale)
def _grow_scale(self) -> None:
- """Increase the loss scale
- """
+ """Increase the loss scale"""
self._scale = self._scale * self._growth_factor
if self._max_scale:
@@ -108,14 +108,14 @@ def _grow_scale(self) -> None:
def state_dict(self):
state_dict = dict()
- state_dict['scale'] = self._scale
- state_dict['growth_factor'] = self._growth_factor
- state_dict['backoff_factor'] = self._backoff_factor
- state_dict['hysteresis'] = self._hysteresis
+ state_dict["scale"] = self._scale
+ state_dict["growth_factor"] = self._growth_factor
+ state_dict["backoff_factor"] = self._backoff_factor
+ state_dict["hysteresis"] = self._hysteresis
return state_dict
def load_state_dict(self, state_dict):
- self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
- self._growth_factor = state_dict['growth_factor']
- self._backoff_factor = state_dict['backoff_factor']
- self._hysteresis = state_dict['hysteresis']
+ self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+ self._growth_factor = state_dict["growth_factor"]
+ self._backoff_factor = state_dict["backoff_factor"]
+ self._hysteresis = state_dict["hysteresis"]
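
Note: the policy reformatted above is standard dynamic loss scaling — back off on overflow, grow after a streak of clean steps. A self-contained sketch of the update rule (plain floats, no hysteresis or min/max clamping; not the `DynamicGradScaler` API):

```python
class TinyDynamicScaler:
    def __init__(self, initial_scale=2**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000):
        self.scale = float(initial_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, overflow: bool) -> None:
        if overflow:
            self.scale *= self.backoff_factor     # inf/nan grads: shrink scale
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor  # clean streak: grow scale


scaler = TinyDynamicScaler(growth_interval=2)
for overflow in (False, False, True, False):
    scaler.update(overflow)
print(scaler.scale)  # 2**16 * 2 (growth) * 0.5 (backoff) == 2**16
```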
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
index b0348e1477bb..a31811e4a567 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
@@ -3,7 +3,7 @@
from .fp16 import FP16MixedPrecisionMixin
__all__ = [
- 'MixedPrecisionMixin',
- 'FP16MixedPrecisionMixin',
- 'BF16MixedPrecisionMixin',
+ "MixedPrecisionMixin",
+ "FP16MixedPrecisionMixin",
+ "BF16MixedPrecisionMixin",
]
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
index a52a9747ad1e..fc7e0b74179a 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
@@ -39,6 +39,7 @@ def zero_grad(self):
return self.optim.zero_grad()
```
"""
+
dtype: torch.dtype
@abstractmethod
@@ -51,7 +52,6 @@ def pre_backward(self, loss: Tensor) -> Tensor:
Returns:
Tensor: Loss value (possibly scaled).
"""
- pass
@abstractmethod
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
@@ -64,7 +64,6 @@ def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
Returns:
Tensor: Gradient of the tensor (possibly scaled).
"""
- pass
@abstractmethod
def should_skip_step(self) -> bool:
@@ -73,13 +72,10 @@ def should_skip_step(self) -> bool:
Returns:
bool: Whether to skip the step.
"""
- pass
@abstractmethod
def pre_zero_grad(self) -> None:
- """Called before zero_grad.
- """
- pass
+ """Called before zero_grad."""
@abstractmethod
def get_grad_div_scale(self) -> float:
@@ -88,4 +84,3 @@ def get_grad_div_scale(self) -> float:
Returns:
float: A divisor for gradient clipping or step.
"""
- pass
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
index 1ce8e42eb3ed..9ce272356797 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
@@ -19,22 +19,26 @@ class OptimState(Enum):
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.float16
- def __init__(self,
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32) -> None:
+ def __init__(
+ self,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ ) -> None:
super().__init__()
- self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
- min_scale=min_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- max_scale=max_scale)
+ self.grad_scaler = DynamicGradScaler(
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
@@ -49,7 +53,6 @@ def check_local_overflow(self) -> bool:
Returns:
bool: Whether there is overflow in the local process.
"""
- pass
def check_overflow(self) -> bool:
# clear previous overflow record
@@ -79,6 +82,6 @@ def pre_zero_grad(self) -> None:
pass
def get_grad_div_scale(self) -> float:
- assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
+ assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
self.optim_state = OptimState.UNSCALED
return self.loss_scale
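
Note: `check_local_overflow` is left abstract here; implementations typically scan the local gradients for non-finite values before the collective in `check_overflow`. A hedged sketch of such a check:

```python
import torch


def local_grad_overflow(params) -> bool:
    # Any non-finite gradient on this rank means the step should be skipped
    # and the loss scale backed off.
    for p in params:
        if p.grad is not None and not torch.isfinite(p.grad).all():
            return True
    return False


p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.tensor([1.0, float("inf"), 0.0])
print(local_grad_overflow([p]))  # True
```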
diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
index 626a00c96d04..6a192cc5cb83 100644
--- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py
+++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -11,18 +11,20 @@
class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
-
- def __init__(self,
- working_params: List[Parameter],
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32) -> None:
- super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
- max_scale)
+ def __init__(
+ self,
+ working_params: List[Parameter],
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ ) -> None:
+ super().__init__(
+ initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+ )
self.params = working_params
def check_local_overflow(self) -> bool:
@@ -33,38 +35,41 @@ def check_local_overflow(self) -> bool:
class MixedPrecisionOptimizer(OptimizerWrapper):
-
- def __init__(self,
- optim: Optimizer,
- precision: str = 'fp16',
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0.0):
+ def __init__(
+ self,
+ optim: Optimizer,
+ precision: str = "fp16",
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ ):
super().__init__(optim)
- if precision == 'fp16':
+ if precision == "fp16":
working_params = []
for group in self.optim.param_groups:
- for p in group['params']:
+ for p in group["params"]:
working_params.append(p)
- self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
- initial_scale=initial_scale,
- min_scale=min_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- max_scale=max_scale)
- elif precision == 'bf16':
+ self.mixed_precision = NaiveFP16MixedPrecisionMixin(
+ working_params,
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
+ elif precision == "bf16":
self.mixed_precision = BF16MixedPrecisionMixin()
else:
- raise ValueError(f'Unsupported precision: {precision}')
+ raise ValueError(f"Unsupported precision: {precision}")
if max_norm > 0.0:
- raise NotImplementedError('max_norm is not supported yet.')
+ raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -72,7 +77,7 @@ def __init__(self,
# create master weights
for group in self.optim.param_groups:
master_params = []
- for p in group['params']:
+ for p in group["params"]:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
@@ -80,7 +85,7 @@ def __init__(self,
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
- group['params'] = master_params
+ group["params"] = master_params
def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
@@ -101,24 +106,24 @@ def _unscale_and_clip_grads(self, total_norm: float) -> None:
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()
- if self.max_norm > 0.:
+ if self.max_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
for group in self.param_groups:
- for p in group['params']:
+ for p in group["params"]:
if p.grad is None:
continue
- p.grad.data.mul_(1. / div_scale)
+ p.grad.data.mul_(1.0 / div_scale)
def _compute_grad_norm(self) -> float:
- if self.max_norm <= 0.:
- return 0.
- grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
+ if self.max_norm <= 0.0:
+ return 0.0
+ grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0:
- return 0.
+ return 0.0
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
@@ -130,7 +135,7 @@ def step(self, *args, **kwargs):
return
# prepare grads
for group in self.optim.param_groups:
- for p in group['params']:
+ for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
@@ -142,7 +147,7 @@ def step(self, *args, **kwargs):
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
- for p in group['params']:
+ for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
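
Note: the reformatted `step` above performs the classic mixed-precision round trip — copy fp16 gradients into fp32 master parameters, step the optimizer on the masters, then copy the results back into the working fp16 parameters. A minimal sketch of that round trip for a single parameter, with plain SGD standing in for `self.optim.step()`:

```python
import torch

working = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
master = working.detach().float().clone().requires_grad_()

working.grad = torch.randn(4, dtype=torch.float16)
master.grad = working.grad.float()      # prepare grads in fp32
with torch.no_grad():
    master -= 0.1 * master.grad         # stand-in for self.optim.step()
    working.copy_(master.half())        # update working params
```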
diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py
index af4349865a7b..7de56f80525a 100644
--- a/colossalai/auto_parallel/checkpoint/build_c_ext.py
+++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py
@@ -3,14 +3,16 @@
from setuptools import Extension, setup
this_dir = os.path.dirname(os.path.abspath(__file__))
-ext_modules = [Extension(
- 'rotorc',
- sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
-)]
+ext_modules = [
+ Extension(
+ "rotorc",
+ sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")],
+ )
+]
setup(
- name='rotor c extension',
- version='0.1',
- description='rotor c extension for faster dp computing',
+ name="rotor c extension",
+ version="0.1",
+ description="rotor c extension for faster dp computing",
ext_modules=ext_modules,
)
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
index b388d00ac553..8aaa690b333c 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -12,13 +12,13 @@
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
-__all___ = ['CheckpointSolverBase']
+__all___ = ["CheckpointSolverBase"]
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
- if n_src.op == 'output':
+ if n_src.op == "output":
n_dst.meta = n_src.meta
@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
class CheckpointSolverBase(ABC):
-
def __init__(
self,
graph: Graph,
@@ -81,13 +80,10 @@ def __init__(
@abstractmethod
def solve(self):
- """Solve the checkpointing problem and return the solution.
- """
- pass
+ """Solve the checkpointing problem and return the solution."""
def get_node_list(self):
- """Get the node list.
- """
+ """Get the node list."""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
@@ -140,8 +136,7 @@ def _is_sink() -> bool:
"""
def _is_inplace(n: Node):
- """Get the inplace argument from ``torch.fx.Node``
- """
+ """Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
@@ -150,19 +145,22 @@ def _is_inplace(n: Node):
return inplace
def _is_shape_consistency(n: Node):
- """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
- """
+ """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
- return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
- map(_is_shape_consistency, n.users))
+ return (
+ not sum([v for _, v in deps.items()])
+ and not any(map(_is_inplace, n.users))
+ and not any(map(_is_shape_consistency, n.users))
+ )
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
- assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
- f"Common node {name} is not an input of the model."
+ assert (
+ next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+ ), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
@@ -187,8 +185,9 @@ def _is_shape_consistency(n: Node):
region = []
# propagate common node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
- ]) or _is_cop(n.target):
+ if len(n.all_input_nodes) == len(
+ [node for node in n.all_input_nodes if node.name in self.cnode]
+ ) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
index 19b2ef5987c9..ab16cc04b730 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
@@ -8,11 +8,10 @@
from .ckpt_solver_base import CheckpointSolverBase
-__all__ = ['CheckpointSolverChen']
+__all__ = ["CheckpointSolverChen"]
class CheckpointSolverChen(CheckpointSolverBase):
-
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
@@ -40,14 +39,14 @@ def solve(self) -> Graph:
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
- checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
+ checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
- n.meta['activation_checkpoint'] = i
+ n.meta["activation_checkpoint"] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
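
Note: `run_chen_greedy` follows Algorithm 3 of arXiv:1604.06174 — walk the linearized node list and cut a new checkpoint segment whenever the accumulated cost exceeds a budget `b` (the grid search tries budgets around the square root of the total cost). A simplified sketch of the segmentation, not the solver's actual implementation:

```python
def chen_greedy(costs, b):
    # Start a new checkpoint segment whenever the running cost exceeds b.
    segments, start, acc = [], 0, 0
    for i, c in enumerate(costs):
        acc += c
        if acc > b:
            segments.append((start, i + 1))
            start, acc = i + 1, 0
    if start < len(costs):
        segments.append((start, len(costs)))
    return segments


print(chen_greedy([1] * 9, b=2))  # [(0, 3), (3, 6), (6, 9)]
```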
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index 21c3bf0da758..d10c41ae2b96 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -1,5 +1,5 @@
from copy import deepcopy
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
@@ -18,17 +18,18 @@
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
-__all__ = ['CheckpointSolverRotor']
+__all__ = ["CheckpointSolverRotor"]
class CheckpointSolverRotor(CheckpointSolverBase):
-
- def __init__(self,
- graph: Graph,
- free_memory: float = -1,
- cnode: List[str] = None,
- memory_slots: int = 500,
- optim_multiplier: float = 1.0):
+ def __init__(
+ self,
+ graph: Graph,
+ free_memory: float = -1,
+ cnode: List[str] = None,
+ memory_slots: int = 500,
+ optim_multiplier: float = 1.0,
+ ):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
@@ -85,13 +86,14 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
# backtrack
try:
- self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
- self.back_ptr)
+ self.sequence = self._backtrack(
+ chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr
+ )
self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e:
            # use the logger to announce that the solver failed
logger = get_dist_logger()
- logger.warning(f'Checkpoint solver failed: {e}')
+ logger.warning(f"Checkpoint solver failed: {e}")
raise ValueError
if verbose:
@@ -100,14 +102,19 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
return deepcopy(self.graph)
def print_chain(self):
- print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
+ print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
- print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
- self.chain.btmp[idx])
- print(f'Chain = {self.chain}')
+ print(
+ self.node_list[idx],
+ self.chain.x[idx + 1],
+ self.chain.xbar[idx + 1],
+ self.chain.ftmp[idx],
+ self.chain.btmp[idx],
+ )
+ print(f"Chain = {self.chain}")
def print_sequence(self):
- print(f'Sequence = {self.sequence}')
+ print(f"Sequence = {self.sequence}")
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
@@ -138,14 +145,14 @@ def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
btime = 0
fwd_mem_peak = 0
for n in node:
- assert isinstance(n, Node), f'{n} is not a Node'
+ assert isinstance(n, Node), f"{n} is not a Node"
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
                # in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
- xbar += n.meta['fwd_mem_out']
- fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+ xbar += n.meta["fwd_mem_out"]
+ fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
- fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+ fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
@@ -162,14 +169,14 @@ def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
- if node.op == 'placeholder':
- input_tensors.append(node.meta['fwd_out'])
+ if node.op == "placeholder":
+ input_tensors.append(node.meta["fwd_out"])
return input_tensors
@staticmethod
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
- return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
+ return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
@@ -180,8 +187,8 @@ def _extract_deps_size():
for k, v in deps.items():
k: Node
if v > 0:
- deps_size += k.meta['bwd_mem_out']
- if v == float('-inf'):
+ deps_size += k.meta["bwd_mem_out"]
+ if v == float("-inf"):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
@@ -190,12 +197,12 @@ def _extract_deps_size():
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
- btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
+ btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
- deps[child] = float('-inf') # free
+ deps[child] = float("-inf") # free
return btmp
@staticmethod
@@ -244,10 +251,11 @@ def _compute_table(chain: Chain, mmax: int) -> Tuple:
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
- leaf_checkpoints = [(j,
- sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
- for j in range(i + 1, idx + 1)
- if m >= x[j]]
+ leaf_checkpoints = [
+ (j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
+ for j in range(i + 1, idx + 1)
+ if m >= x[j]
+ ]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
@@ -274,13 +282,16 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
import os
import subprocess
import sys
+
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
- f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
- f"--build-lib={this_dir}"
+ f"{sys.executable}",
+ f"{os.path.join(this_dir, 'build_c_ext.py')}",
+ "build_ext",
+ f"--build-lib={this_dir}",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
@@ -294,8 +305,9 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
return compute_table(chain, mmax)
@staticmethod
- def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
- back_ptr: List[Any]) -> "Sequence":
+ def _backtrack(
+ chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any]
+ ) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
@@ -328,8 +340,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
- CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
- back_ptr),
+ CheckpointSolverRotor._backtrack(
+ chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr
+ ),
Backward(lhs),
]
else:
@@ -337,8 +350,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
- CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
- back_ptr),
+ CheckpointSolverRotor._backtrack(
+ chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr
+ ),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
@@ -353,8 +367,8 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
- fwd_list = op_list[:op_list.index(loss_op)]
- bwd_list = op_list[op_list.index(loss_op) + 1:]
+ fwd_list = op_list[: op_list.index(loss_op)]
+ bwd_list = op_list[op_list.index(loss_op) + 1 :]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
@@ -369,7 +383,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'] = [ckpt_idx]
+ n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
@@ -377,7 +391,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'] = [ckpt_idx]
+ n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
@@ -397,7 +411,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
@@ -405,7 +419,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
@@ -413,7 +427,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
in_recompute = False
@@ -431,9 +445,11 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
- for (start_idx, end_idx) in ckpt_regions:
+ for start_idx, end_idx in ckpt_regions:
nested_length = max(
- len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
+ len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1)
+ )
for idx in range(start_idx, end_idx + 1):
- op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
- len(op_list[idx].meta['activation_checkpoint']))
+ op_list[idx].meta["activation_checkpoint"] += [None] * (
+ nested_length - len(op_list[idx].meta["activation_checkpoint"])
+ )
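
Note: the cost table built by `_compute_table` follows the rotor recurrence — to backpropagate through stages `i..idx` under budget `m`, pick a leaf checkpoint `j`, pay `sum(ftime[i:j])` to recompute up to it, and solve the two subchains. A deliberately simplified, memoized sketch of the leaf-checkpoint branch only (single memory axis, uniform costs; the real table also handles the `ForwardEnable` branch and the `xbar` overheads):

```python
from functools import lru_cache


def rotor_cost(ftime, btime, x, m):
    # opt(i, j, m): cheapest backpropagation through stages i..j within budget m.
    n = len(ftime)

    @lru_cache(maxsize=None)
    def opt(i, j, m):
        if m <= 0:
            return float("inf")
        if i == j:
            return ftime[i] + btime[i]
        leaves = [
            sum(ftime[i:k]) + opt(k, j, m - x[k]) + opt(i, k - 1, m)
            for k in range(i + 1, j + 1)
            if m >= x[k]
        ]
        return min(leaves, default=float("inf"))

    return opt(0, n - 1, m)


print(rotor_cost(ftime=[1, 1, 1, 1], btime=[1, 1, 1, 1], x=[1, 1, 1, 1], m=3))  # 12
```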
diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py
index ab0c6c5ad38d..5f8077916433 100644
--- a/colossalai/auto_parallel/checkpoint/operation.py
+++ b/colossalai/auto_parallel/checkpoint/operation.py
@@ -1,20 +1,21 @@
import math
from abc import ABC
-from typing import Any, Iterable, List
+from typing import List
from torch.utils._pytree import tree_map
class Chain:
-
- def __init__(self,
- ftime: List[float],
- btime: List[float],
- x: List[int],
- xbar: List[int],
- ftmp: List[int],
- btmp: List[int],
- check_consistency: bool = True):
+ def __init__(
+ self,
+ ftime: List[float],
+ btime: List[float],
+ x: List[int],
+ xbar: List[int],
+ ftmp: List[int],
+ btmp: List[int],
+ check_consistency: bool = True,
+ ):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
@@ -37,9 +38,14 @@ def __init__(self,
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
- return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
- and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
- and (len(self.xbar) == len(self) + 1))
+ return (
+ (len(self.ftime) == len(self))
+ and (len(self.btime) == len(self) + 1)
+ and (len(self.x) == len(self) + 1)
+ and (len(self.ftmp) == len(self))
+ and (len(self.btmp) == len(self) + 1)
+ and (len(self.xbar) == len(self) + 1)
+ )
def __repr__(self):
chain_list = []
@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
class Forwards(Operation):
-
def __init__(self, start, end):
self.index = (start, end)
@@ -109,9 +114,9 @@ def __repr__(self):
def cost(self, chain: Chain):
if chain is not None:
- return sum(chain.ftime[self.index[0]:self.index[1] + 1])
+ return sum(chain.ftime[self.index[0] : self.index[1] + 1])
else:
- return (self.index[1] - self.index[0] + 1)
+ return self.index[1] - self.index[0] + 1
def isForward(op):
@@ -132,7 +137,6 @@ def cost(self, chain: Chain):
class Loss(Operation):
-
def __init__(self):
pass
@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
class Sequence(list):
-
def __init__(self):
super().__init__()
diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py
index 35b8c13ee8ff..2f638fa919e4 100644
--- a/colossalai/auto_parallel/meta_profiler/constants.py
+++ b/colossalai/auto_parallel/meta_profiler/constants.py
@@ -3,8 +3,6 @@
import torch
import torch.nn as nn
-from ..tensor_shard.constants import *
-
 # list of inplace modules
INPLACE_MODULE = [nn.ReLU]
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index 0f2e9e44f91c..4234481ae2ca 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next(
filter(
- lambda x:
- (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
- args)).data
+ lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
+ and x.name != "softmax_dim",
+ args,
+ )
+ ).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
- is_inplace = 1 if kwargs.get('inplace', False) else 0
+ is_inplace = 1 if kwargs.get("inplace", False) else 0
flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
        # NOTE: currently in SPMD solver we always assume that a new tensor is created in forward
        # NOTE: if inplace is True, we will not create a new tensor in forward
- fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
- parameter=0,
- temp=0,
- buffer=activation_size(input_tensor) * buffer_mem_scale)
+ fwd_memory_cost = MemoryCost(
+ activation=activation_size(input_tensor) * (2 - is_inplace),
+ parameter=0,
+ temp=0,
+ buffer=activation_size(input_tensor) * buffer_mem_scale,
+ )
# temp_mem_scale is for situation like softmax backward
# the buffer will be removed during backward phase
@@ -54,20 +58,23 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
- buffer=0)
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
- temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
- buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
- fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
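
As a side note for readers tracing this hunk, the elementwise accounting reduces to a few lines of arithmetic. The sketch below is illustrative only: it assumes 4 bytes per fp32 element in place of activation_size, and the helper name is invented.

    # Minimal sketch of the elementwise memory accounting above.
    def elementwise_memory_bytes(numel: int, inplace: bool,
                                 temp_mem_scale: float = 0.0,
                                 buffer_mem_scale: float = 0.0):
        size = 4 * numel                                # assumed fp32 byte count
        is_inplace = 1 if inplace else 0
        fwd_activation = size * (2 - is_inplace)        # input (+ fresh output unless inplace)
        fwd_buffer = size * buffer_mem_scale            # e.g. softmax keeps its output
        bwd_activation = size - fwd_buffer              # the buffer part is dropped in backward
        bwd_temp = size * temp_mem_scale + fwd_buffer   # the freed buffer shows up as temp
        return fwd_activation, fwd_buffer, bwd_activation, bwd_temp

    print(elementwise_memory_bytes(1024, inplace=True))   # ReLU(inplace=True): no output copy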
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index e451748512b9..0b7b51a71955 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -6,10 +6,10 @@
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
+from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
-__all__ = ['binary_elementwise_meta_info']
+__all__ = ["binary_elementwise_meta_info"]
@meta_register.register(BCAST_FUNC_OP)
@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
+ fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
index 4336bf68363c..2f630995cdbc 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
@@ -1,22 +1,14 @@
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
-__all__ = ['convnd_meta_info']
+__all__ = ["convnd_meta_info"]
@meta_register.register(torch.nn.Conv1d)
@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
- flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
+ bwd_compute_cost = (
+ flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
+ if has_bias
+ else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
+ )
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
-
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
+
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes([input_tensor, weight_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
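
The bias-dependent branches above boil down to a small piece of arithmetic. A hedged standalone rendering (plain byte counts stand in for compute_size_in_bytes, and the function name is invented):

    def convnd_memory_bytes(in_bytes: int, out_bytes: int,
                            weight_bytes: int, bias_bytes: int, has_bias: bool):
        param = weight_bytes + bias_bytes if has_bias else weight_bytes
        fwd_activation = in_bytes + out_bytes     # input plus a fresh output
        bwd_activation = in_bytes + param         # gradients for input and parameters
        return {"fwd": (fwd_activation, param),
                "bwd": (bwd_activation, param),
                "total": (fwd_activation + bwd_activation, 2 * param)}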
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
index d5d80f5b3700..7c9add810fd8 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
- bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
- [weight_tensor])
+ bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
+ [output_tensor, weight_tensor], [weight_tensor]
+ )
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, it seems that when the input is large enough, a temp
# buffer appears; we don't know the reason yet, so currently we assume there is no temp memory, as the temp
# memory is significantly smaller than the gradient memory
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=0,
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
+ )
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index 94dd9143e0ae..d731f9cb4436 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -1,23 +1,15 @@
from functools import reduce
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
-__all__ = ['linear_meta_info', 'matmul_meta_info']
+__all__ = ["linear_meta_info", "matmul_meta_info"]
@meta_register.register(torch.nn.functional.linear)
@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
- [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
- flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
- flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+ )
+ bwd_compute_cost = (
+ flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
+ + flop_mapping[torch.ops.aten.mm.default](
+ [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+ )
+ + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
+ )
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
# NOTE: Linear doesn't have buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=0,
+ )
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=0)
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
- flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
+ [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+ )
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [output_tensor, weight_tensor], (input_tensor,)
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+ )
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
# NOTE: Linear doesn't have buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
- parameter=compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor]),
+ parameter=compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
+ )
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
- [output_tensors[0].reshape(-1), input_tensors[1]],
- output_tensors) + \
- flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
- output_tensors)
+ [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default](
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
+ output_tensors,
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
- bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
- flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+ [output_tensors[0], input_tensors[0]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]),
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]),
+ buffer=0,
+ )
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
- [output_tensors[0].reshape(-1)])
+ [output_tensors[0].reshape(-1)],
+ )
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
- [output_tensors[0].reshape(-1), input_tensors[0]],
- output_tensors
- ) + \
- flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
- output_tensors
- )
+ [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default](
+ [
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
+ output_tensors[0].reshape(-1),
+ ],
+ output_tensors,
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]),
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors[0]),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]),
+ buffer=0,
+ )
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
- [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
+ )
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
- [input_tensors[1]]
- ) + \
- flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
- )
+ [
+ input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
+ output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
+ ],
+ [input_tensors[1]],
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
- fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
- input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
- 0, 1)
- ], [output_tensors[0].transpose(-2, -1)])
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+ input_tensors[0].transpose(0, 1),
+ ],
+ [output_tensors[0].transpose(-2, -1)],
+ )
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
- [input_tensors[0]]
- ) + \
- flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
- [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
- )
-
- fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
- compute_size_in_bytes(input_tensors[1]),
- temp=compute_size_in_bytes(output_tensors))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
+ [
+ output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+ ],
+ [input_tensors[0]],
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
+ [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
+ )
+
+ fwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
+ temp=compute_size_in_bytes(output_tensors),
+ )
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors[0]),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
+ )
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
- for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
+ for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
- fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
- input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
- -1, input_dim_10, input_dim_11)
- ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+ fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+ [
+ input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
+ input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
+ ],
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ )
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
- [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
- ) + \
- flop_mapping[torch.ops.aten.bmm.default](
- [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
- [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
- )
+ [
+ input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
+ output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+ ],
+ [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
+ ) + flop_mapping[torch.ops.aten.bmm.default](
+ [
+ output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
+ ],
+ [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
- extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
- input_dim_00,
- input_dim_01,
- device="meta")
- extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
- input_dim_10,
- input_dim_11,
- device="meta")
+ extended_input_0 = torch.rand(
+ reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
+ )
+ extended_input_1 = torch.rand(
+ reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
+ )
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+ [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
+ )
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
- [extended_input_1]
- ) + \
- flop_mapping[torch.ops.aten.bmm.default](
- [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
- [extended_input_0]
- )
+ [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ [extended_input_1],
+ ) + flop_mapping[torch.ops.aten.bmm.default](
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
+ [extended_input_0],
+ )
fwd_mem_cost = MemoryCost(
- activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
- compute_size_in_bytes([extended_input_0, extended_input_1]),
- temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
+ activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
+ )
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors)
+ - compute_size_in_bytes([extended_input_0, extended_input_1]),
+ temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
+ )
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
- total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ total_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
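
matmul_meta_info is easiest to follow as a pure question about input ranks. A hedged sketch of the dispatch (the first two conditions are inferred, since their elif lines fall outside the visible hunks; the rest mirror the branches above):

    def classify_matmul(shape_a, shape_b) -> str:
        if len(shape_a) >= 2 and len(shape_b) == 1:     # assumed condition
            return "batched gemv case 1: batched matrix-vector"
        if len(shape_a) == 1 and len(shape_b) == 2:     # assumed condition
            return "gemv case 2: vector-matrix"
        if len(shape_a) == 1 and len(shape_b) >= 3:
            return "batched gemv case 2: vector-batched matrix"
        if len(shape_a) >= 2 and len(shape_b) == 2:
            return "gemm case 1: batched matrix-matrix"
        if len(shape_a) == 2 and len(shape_b) >= 3:
            return "batched gemm case 2: matrix-batched matrix"
        if len(shape_a) >= 3 and len(shape_b) >= 3:
            return "batched matrix-batched matrix"
        raise ValueError(f"unsupported shapes {shape_a} x {shape_b}")

    print(classify_matmul((8, 16, 32), (32,)))        # batched gemv case 1
    print(classify_matmul((2, 4, 8), (2, 8, 16)))     # batched matrix-batched matrix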
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
index 12874810b13e..b1bb1d872c35 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
@@ -3,7 +3,7 @@
import torch
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
index b872fdc8bdcd..99aaa752d0a1 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
@@ -1,22 +1,14 @@
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
-__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
+__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]
@meta_register.register(torch.nn.BatchNorm1d)
@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [
- output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
+ output_tensor,
+ output_tensor,
+ weight_tensor,
+ mean_tensor,
+ var_tensor,
+ mean_tensor,
+ var_tensor,
+ 1e-5,
+ num_batch,
]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
- [input_tensor, output_tensor, mean_tensor, var_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+ )
# the bwd memory cost is quite tricky here: BatchNorm removes the saved mean
# and saved inv std during the backward phase
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=compute_size_in_bytes([mean_tensor, var_tensor]),
- buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=compute_size_in_bytes([mean_tensor, var_tensor]),
+ buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
- running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
- running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
+ running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
+ running_var = torch.rand(input_tensor.shape[0], 1, device="meta")
# construct args
fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
- [input_tensor, output_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=compute_size_in_bytes([running_mean, running_var]))
-
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=compute_size_in_bytes([running_mean, running_var]),
- buffer=compute_size_in_bytes([running_mean, running_var]))
-
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
- temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
- buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=compute_size_in_bytes([running_mean, running_var]),
+ )
+
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=compute_size_in_bytes([running_mean, running_var]),
+ buffer=compute_size_in_bytes([running_mean, running_var]),
+ )
+
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
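
The "tricky" BatchNorm bookkeeping called out above is only about where the saved statistics are counted. A minimal restatement (invented helper, plain byte counts):

    def batchnorm_stat_bytes(mean_bytes: int, var_bytes: int):
        stats = mean_bytes + var_bytes
        fwd = {"buffer": stats}                  # saved mean / inv std survive forward
        bwd = {"temp": stats, "buffer": stats}   # and are freed during backward
        return fwd, bwd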
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
index d785dfcca9ba..21aa524bed08 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
- temp=compute_size_in_bytes(index_matrix))
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
+ temp=compute_size_in_bytes(index_matrix),
+ )
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
index 97fe3c6196f5..9a2df1bd7c87 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
@@ -2,7 +2,6 @@
import torch
-from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
@@ -37,15 +36,19 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor
# NOTE: currently in the SPMD solver we always assume that a new tensor is created in the forward pass
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
- parameter=0,
- temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
+ parameter=0,
+ temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
+ buffer=0,
+ )
- total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ total_mem_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
@@ -66,14 +69,24 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor
# register torch.Tensor related metainfo
# (0, 0)
-meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
- torch.arange])(tensor_related_metainfo(0, 0))
+meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
+ tensor_related_metainfo(0, 0)
+)
# (1, 0)
-meta_register.register([
- torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
- torch.Tensor.split, torch.split, torch.Tensor.view
-])(tensor_related_metainfo(1, 0))
+meta_register.register(
+ [
+ torch.Tensor.flatten,
+ torch.flatten,
+ torch.Tensor.transpose,
+ torch.transpose,
+ torch.Tensor.permute,
+ torch.permute,
+ torch.Tensor.split,
+ torch.split,
+ torch.Tensor.view,
+ ]
+)(tensor_related_metainfo(1, 0))
# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
index 5cba1b5b6e2b..107851b80d7c 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
@@ -4,7 +4,7 @@
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# NOTE: currently in the SPMD solver we always assume that a new input tensor is created in the forward pass
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
- bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
- parameter=0,
- temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
- activation_size([x_tensor, y_tensor]),
- buffer=0)
-
- total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ bwd_mem_cost = MemoryCost(
+ activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
+ parameter=0,
+ temp=activation_size([output_tensor]) * 3
+ + activation_size([condition_tensor])
+ - activation_size([x_tensor, y_tensor]),
+ buffer=0,
+ )
+
+ total_mem_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py
index 46350c4dd406..c29086f7f9d1 100644
--- a/colossalai/auto_parallel/meta_profiler/registry.py
+++ b/colossalai/auto_parallel/meta_profiler/registry.py
@@ -1,14 +1,12 @@
-__all__ = ['Registry']
+__all__ = ["Registry"]
class Registry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
@@ -21,7 +19,7 @@ def wrapper(func):
return wrapper
def get(self, source):
- assert source in self.store, f'{source} not found in the {self.name} registry'
+ assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
@@ -29,4 +27,4 @@ def has(self, source):
return source in self.store
-meta_register = Registry('meta')
+meta_register = Registry("meta")
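
For orientation, this registry is used in call style throughout the patch. A hedged usage sketch (the keys and function are made up; it assumes wrapper stores the function under every listed key, as the in-code comment indicates):

    reg = Registry("demo")

    def my_meta_func(*args, **kwargs):
        return "meta info"

    # same call pattern as meta_register.register([...])(...) elsewhere in this patch
    reg.register(["op_a", "op_b"])(my_meta_func)

    assert reg.has("op_a")
    assert reg.get("op_b") is my_meta_func   # get() asserts the key exists first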
diff --git a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
index 0eee908b48b7..109b8a220ac7 100644
--- a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
+++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
@@ -2,20 +2,13 @@
import torch
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
-__all__ = ['ShardMetaInfo']
+__all__ = ["ShardMetaInfo"]
class ShardMetaInfo:
@@ -76,10 +69,12 @@ def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: S
"""
if isinstance(sharding_spec, ShardingSpec):
- op_data = OperationData(name=operation_data.name,
- data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
- type=operation_data.type,
- logical_shape=operation_data.logical_shape)
+ op_data = OperationData(
+ name=operation_data.name,
+ data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+ type=operation_data.type,
+ logical_shape=operation_data.logical_shape,
+ )
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
@@ -97,8 +92,9 @@ def compute_shard_metainfo(self):
"""
Compute meta info based on sharding strategy and the given target function.
"""
- assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
- f"Meta info for {self._target} is not registered."
+ assert meta_register.has(self._target.__class__) or meta_register.has(
+ self._target
+ ), f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
@@ -117,11 +113,11 @@ def compute_shard_metainfo(self):
# construct kwargs
if self.target in INPLACE_MODULE:
- kwargs = {'inplace': self.target.inplace}
+ kwargs = {"inplace": self.target.inplace}
elif self.target in INPLACE_OPS:
- kwargs = {'inplace': True}
+ kwargs = {"inplace": True}
else:
- kwargs = {'inplace': False}
+ kwargs = {"inplace": False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index 353133bd6f2d..601bf2926d99 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -37,19 +37,20 @@ class AMPOptimizer(OptimizerWrapper):
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
- def __init__(self,
- optimizer: Optimizer,
- module: BaseOffloadModule,
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- clipping_norm: float = 0.0,
- norm_type: float = 2.0):
-
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ module: BaseOffloadModule,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ clipping_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ):
super().__init__(optimizer)
self.module = module
@@ -69,19 +70,21 @@ def __init__(self,
self.__init__optimizer()
# Grad scaler
- self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
- min_scale=min_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- max_scale=max_scale)
+ self.grad_scaler = DynamicGradScaler(
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()
def _set_grad_ptr(self):
for group in self.param_groups:
- for fake_param in group['params']:
+ for fake_param in group["params"]:
region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param]
@@ -92,7 +95,7 @@ def _set_grad_ptr(self):
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
- for fake_param in group['params']:
+ for fake_param in group["params"]:
assert fake_param.grad is None
fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None
@@ -130,10 +133,10 @@ def step(self, *args, **kwargs):
found_inf = self._check_overflow()
if found_inf:
- self.optim_state = OptimState.UNSCALED # no need to unscale grad
- self.grad_scaler.update(found_inf) # update gradient scaler
- self._logger.info(f'Found overflow. Skip step')
- self.zero_grad() # reset all gradients
+ self.optim_state = OptimState.UNSCALED # no need to unscale grad
+ self.grad_scaler.update(found_inf) # update gradient scaler
+ self._logger.info(f"Found overflow. Skip step")
+ self.zero_grad() # reset all gradients
self._update_fp16_params()
return
@@ -156,11 +159,10 @@ def backward(self, loss: torch.Tensor):
self.module.backward(loss)
def __init__optimizer(self):
-
for group in self.optim.param_groups:
fake_params_list = list()
- for param in group['params']:
+ for param in group["params"]:
region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param]
@@ -171,7 +173,7 @@ def __init__optimizer(self):
if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param)
- group['params'] = fake_params_list
+ group["params"] = fake_params_list
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
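
The overflow path in step() leans on the usual dynamic loss-scaling policy behind DynamicGradScaler. A hedged sketch of that policy (a toy class, not the real scaler; hysteresis is omitted):

    class ToyGradScaler:
        def __init__(self, initial_scale=2**16, growth_factor=2.0, backoff_factor=0.5,
                     growth_interval=1000, min_scale=1.0, max_scale=2**32):
            self.scale = float(initial_scale)
            self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
            self.growth_interval = growth_interval
            self.min_scale, self.max_scale = min_scale, max_scale
            self._good_steps = 0

        def update(self, found_inf: bool):
            if found_inf:
                # overflow: back off and restart the growth counter (the step is skipped)
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._good_steps = 0
            else:
                self._good_steps += 1
                if self._good_steps % self.growth_interval == 0:
                    self.scale = min(self.scale * self.growth_factor, self.max_scale)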
diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py
index 5b9f74b132f3..f5e8e31f5e97 100644
--- a/colossalai/auto_parallel/offload/base_offload_module.py
+++ b/colossalai/auto_parallel/offload/base_offload_module.py
@@ -22,7 +22,6 @@ class BaseOffloadModule:
"""
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
-
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
@@ -91,17 +90,16 @@ def _cast_buffers(self):
def parameters(self, recurse: bool = True):
return self.model.parameters(recurse)
- def named_parameters(self, prefix: str = '', recurse: bool = True):
+ def named_parameters(self, prefix: str = "", recurse: bool = True):
return self.model.named_parameters(prefix, recurse)
- def named_buffers(self, prefix: str = '', recurse: bool = True):
+ def named_buffers(self, prefix: str = "", recurse: bool = True):
return self.model.named_buffers(prefix, recurse)
def named_children(self):
return self.model.named_children()
- def named_modules(self,
- memo: Optional[Set[torch.nn.Module]] = None,
- prefix: str = '',
- remove_duplicate: bool = True):
+ def named_modules(
+ self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+ ):
return self.model.named_modules(memo, prefix, remove_duplicate)
diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py
index d56166dea982..74501c184518 100644
--- a/colossalai/auto_parallel/offload/mem_optimize.py
+++ b/colossalai/auto_parallel/offload/mem_optimize.py
@@ -14,11 +14,9 @@
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
-def memory_optimize(model: torch.nn.Module,
- inps: Dict[str, torch.Tensor],
- memory_budget: float = -1.0,
- solver_name: str = 'asyn'):
-
+def memory_optimize(
+ model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn"
+):
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)
- if solver_name == 'syn':
+ if solver_name == "syn":
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
- elif solver_name == 'asyn':
+ elif solver_name == "asyn":
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else:
raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile()
- optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
+ optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn")
return optimized_model
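
A hedged usage sketch for the rewritten signature (the toy model and input dict are invented, the budget is assumed to be in MB to match the log line above, and the offload machinery requires a CUDA device):

    import torch
    from colossalai.auto_parallel.offload.mem_optimize import memory_optimize

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
    inps = {"input": torch.randn(16, 1024)}   # keys assumed to name forward() arguments

    optimized_model = memory_optimize(model, inps, memory_budget=512.0, solver_name="asyn")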
diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py
index 819ffbd96eb1..ea92c714ce31 100644
--- a/colossalai/auto_parallel/offload/region.py
+++ b/colossalai/auto_parallel/offload/region.py
@@ -55,13 +55,13 @@ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
Map the parameters in the region to a contiguous memory space.
"""
- self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
+ self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda")
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
- self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
- param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
+ self.fp16_data[offset : offset + p_num].copy_(param.data.flatten())
+ param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
@@ -83,7 +83,7 @@ def move_param_to_cuda(self):
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data)
- self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
+ self.fp16_data[: self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr()
@@ -94,7 +94,7 @@ def move_grad_to_cpu(self):
"""
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
- self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
+ self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
self.free_cuda_data()
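
The slice-spacing changes above are cosmetic, but the packing trick itself is worth seeing in isolation: parameters are copied into one flat fp16 buffer and then re-pointed at views of it. A CPU-only sketch (the real code allocates on CUDA):

    import torch

    params = [torch.nn.Parameter(torch.randn(4, 4)), torch.nn.Parameter(torch.randn(8))]
    fp16_data = torch.zeros(sum(p.numel() for p in params), dtype=torch.half)

    offset = 0
    for param in params:
        p_num = param.data.numel()
        fp16_data[offset : offset + p_num].copy_(param.data.flatten())
        # re-point the parameter at a view of the contiguous buffer
        param.data = fp16_data[offset : offset + p_num].view(param.data.shape)
        offset += p_num

    assert params[0].data.data_ptr() == fp16_data.data_ptr()   # storage is shared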
diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py
index 30bfaf00d493..146dd267967d 100644
--- a/colossalai/auto_parallel/offload/region_manager.py
+++ b/colossalai/auto_parallel/offload/region_manager.py
@@ -1,10 +1,11 @@
-from typing import List, Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple
+
import torch
from torch.fx import Graph, Node
+from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
-from .region import Region
from .util import NodeInfo
@@ -19,14 +20,9 @@ class RegionManager:
cnode (List[str], optional): Common node List, should be the subset of input.
"""
- def __init__(self,
- graph: Graph,
- solver_name: str = 'asyn',
- memory_budget: float = -1.0,
- cnode: List[str] = None):
-
+ def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
self.graph = graph
- assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+ assert graph.owning_module is not None, "The given graph is not associated with an owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.cnode = cnode
@@ -39,7 +35,7 @@ def __init__(self,
self.memory_budget = memory_budget
self.solver_name = solver_name
- self.require_pool: bool = solver_name == 'asyn'
+ self.require_pool: bool = solver_name == "asyn"
self.reg_to_block: Dict[int, int] = dict()
@@ -61,22 +57,19 @@ def _build_regions(self):
self._post_process(solver.best_ts)
def _pre_process(self):
-
init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1:
- raise NotImplementedError(
- 'The current version only considers at most one pair of parameter sharing.')
+ raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")
elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
- assert shared_regs[0].shared_rid == shared_regs[1].r_id \
- and shared_regs[1].shared_rid == shared_regs[0].r_id
+ assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id
- regs_left_out = init_region_list[:fst_id + 1]
+ regs_left_out = init_region_list[: fst_id + 1]
regs_right_out = init_region_list[lst_id:]
- hold_regs = init_region_list[fst_id + 1:lst_id]
+ hold_regs = init_region_list[fst_id + 1 : lst_id]
else:
regs_left_out = []
regs_right_out = []
@@ -122,12 +115,9 @@ def _early_region_placement(self, ts: TrainingSimulator):
it may not find a suitable region placement strategy for the given execution flow.
"""
- reg_flow = torch.cat(
- [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
- mem_block_num = torch.max(
- torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
- coexist_matrix = torch.logical_or(
- ts.fwd_reg_flow, ts.bwd_reg_flow)
+ reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
+ mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
+ coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {}
for block_idx in range(mem_block_num):
@@ -135,8 +125,7 @@ def _early_region_placement(self, ts: TrainingSimulator):
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
- cur_reg_coexists = torch.sum(
- coexist_matrix[cur_reg_appears], dim=0).bool()
+ cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
@@ -145,9 +134,12 @@ def _early_region_placement(self, ts: TrainingSimulator):
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
- f'can not find a block from the memory pool to store parameters of the region')
- self.memory_pool = torch.chunk(torch.zeros(int(
- mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
+ f"can not find a block from the memory pool to store parameters of the region"
+ )
+ self.memory_pool = torch.chunk(
+ torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
+ chunks=int(mem_block_num),
+ )
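The pool construction above can be exercised in isolation. A small sketch on CPU (block count and size are made up); `mem_block_size` is in bytes and fp16 elements occupy 2 bytes each, which is where the division by 2 comes from:

import torch

mem_block_num, mem_block_size = 4, 1024  # four 1 KB blocks (illustrative)
memory_pool = torch.chunk(
    torch.zeros(int(mem_block_num * mem_block_size / 2), dtype=torch.half),
    chunks=int(mem_block_num),
)
assert len(memory_pool) == mem_block_num
assert memory_pool[0].numel() == mem_block_size // 2  # 512 fp16 elements = 1 KB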
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
@@ -178,10 +170,9 @@ def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
return region_list
- def _search_block_size(self,
- region_list: List[Region],
- search_interval_byte: int = 1024,
- search_range_byte: int = 128 * 1024 ** 2) -> int:
+ def _search_block_size(
+ self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
+ ) -> int:
"""
Search for a suitable memory block size.
@@ -208,11 +199,10 @@ def _get_wasted_mem(size_list: List[int], blk_size: int):
acc_wasted += blk_size - left
return acc_wasted
- param_size_list = [
- region.param_size for region in region_list if region.r_id == region.shared_rid]
+ param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list)
- min_mem_waste = float('+inf')
+ min_mem_waste = float("+inf")
best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
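The loop above scans candidate block sizes and keeps the one with the least padding waste. A simplified, self-contained version of the idea (the real `_get_wasted_mem` packs several regions into one block before moving on; here each region is assumed to occupy its own block, and all sizes are made up):

def wasted_mem(size_list, blk_size):
    # waste = padding left unused in each block
    return sum(blk_size - size for size in size_list)

param_size_list = [700, 900, 1024]  # per-region parameter bytes (illustrative)
start_size = max(param_size_list)
best_block_size = min(
    range(start_size, start_size + 4096 + 1, 1024),
    key=lambda blk: wasted_mem(param_size_list, blk),
)
print(best_block_size)  # 1024: the smallest candidate already fits every region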
@@ -229,7 +219,7 @@ def _init_region_data(self):
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
- self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
+ self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)
for region in self.region_list:
pre_alloc_tensor = None
@@ -244,8 +234,7 @@ def _init_region_data(self):
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
- region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
- )
+ region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()
torch.cuda.empty_cache()
@@ -259,13 +248,14 @@ def _process_shared_region(self):
former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1]
- assert embedding_node.op == 'call_module' and isinstance(
- self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
+ assert embedding_node.op == "call_module" and isinstance(
+ self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
+ )
if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes):
- if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
- torch.nn.Linear)) or \
- (n.op == 'call_function' and n.target is torch.nn.functional.linear):
+ if (
+ n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
+ ) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1
break
assert len(latter_reg.fp16_params) == 2
@@ -273,7 +263,7 @@ def _process_shared_region(self):
for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg)
- for reg in self.region_list[new_reg.r_id + 1:]:
+ for reg in self.region_list[new_reg.r_id + 1 :]:
reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id
@@ -344,8 +334,8 @@ def _maybe_param_comp_start() -> bool:
target = n.target
submod = self.root_module.get_submodule(target)
if (
- len(list(submod.named_parameters(recurse=False))) != 0
- or len(list(submod.named_buffers(recurse=False))) != 0
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
@@ -362,14 +352,12 @@ def _is_param_comp_end() -> bool:
"""
def _is_inplace(n: Node):
- """Get the inplace argument from ``torch.fx.Node``
- """
+ """Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
- inplace = getattr(n.graph.owning_module.get_submodule(
- n.target), "inplace", False)
+ inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
label = False
@@ -378,28 +366,30 @@ def _is_inplace(n: Node):
target = n.target
submod = self.root_module.get_submodule(target)
if (
- len(list(submod.named_parameters(recurse=False))) != 0
- or len(list(submod.named_buffers(recurse=False))) != 0
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
- map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
+ map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
+ )
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling():
# TODO meta info prop bug
- if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
- n.meta['fwd_out'] = []
+ if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
+ n.meta["fwd_out"] = []
# make sure that each item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
- assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
- f"Common node {name} is not an input of the model."
+ assert (
+ next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+ ), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
@@ -428,8 +418,8 @@ def _exception_node_handling():
ns = []
border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes):
- ns = region.nodes[border_n_idx + 1:]
- region.nodes = region.nodes[:border_n_idx + 1]
+ ns = region.nodes[border_n_idx + 1 :]
+ region.nodes = region.nodes[: border_n_idx + 1]
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
@@ -448,19 +438,21 @@ def _exception_node_handling():
region = Region(r_id=region_id)
# propagate common node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
- ]) or _is_cop(n.target):
+ if len(n.all_input_nodes) == len(
+ [node for node in n.all_input_nodes if node.name in self.cnode]
+ ) or _is_cop(n.target):
self.cnode.append(n.name)
else:
- deps[n] = len(
- [user for user in n.users if user.op != "output"])
+ deps[n] = len([user for user in n.users if user.op != "output"])
# propagate param node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
- ]) or n.op == "get_attr":
+ if (
+ len(n.all_input_nodes)
+ == len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
+ or n.op == "get_attr"
+ ):
self.only_param_ops.append(n.name)
- param_op_deps[n] = len(
- [user for user in n.users if user.op != "output"])
+ param_op_deps[n] = len([user for user in n.users if user.op != "output"])
# record last activation node
if _is_act(n._meta_data):
@@ -472,19 +464,16 @@ def _exception_node_handling():
return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
-
cur_n.node_info = NodeInfo(node_id)
- if cur_n.op == 'call_module':
+ if cur_n.op == "call_module":
target = cur_n.target
submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)):
-
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
- self.shared_region_pairs.append(
- (self.param_region_map[p], cur_reg))
+ self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg
@@ -499,12 +488,10 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter):
-
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
- self.shared_region_pairs.append(
- (self.param_region_map[attr_itr], cur_reg))
+ self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg
diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py
index 764ac608826b..cc790dfb0891 100644
--- a/colossalai/auto_parallel/offload/runtime.py
+++ b/colossalai/auto_parallel/offload/runtime.py
@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
- d2h_rid = fwd_info.get('d2h_rid', None)
+ d2h_rid = fwd_info.get("d2h_rid", None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
- h2d_rid = fwd_info.get('h2d_rid', None)
+ h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
@@ -38,8 +38,7 @@ def forward(ctx, input_, fwd_info, bwd_info):
@staticmethod
def backward(ctx, grad_output):
-
- h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
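Both ops follow the same pattern: a `torch.autograd.Function` that passes its tensor through unchanged while triggering transfer side effects around the forward and backward passes. A minimal sketch of the pattern, with `print` standing in for the region transfers:

import torch

class PreFwdPostBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, tag):
        ctx.tag = tag
        print(f"[fwd] transfer for {tag}")  # e.g. upload or free a region
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        print(f"[bwd] transfer for {ctx.tag}")  # e.g. prefetch a region
        return grad_output, None  # no gradient for the tag argument

x = torch.ones(2, requires_grad=True)
y = PreFwdPostBwd.apply(x, "region-0") * 3
y.sum().backward()  # prints the forward side effect, then the backward one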
@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
- sync_rid = fwd_info.get('sync_rid', None)
+ sync_rid = fwd_info.get("sync_rid", None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
- h2d_rid = fwd_info.get('h2d_rid', None)
+ h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -87,8 +86,7 @@ def forward(ctx, input_, fwd_info, bwd_info):
@staticmethod
def backward(ctx, grad_output):
-
- sync_rid = ctx.bwd_info.get('sync_rid', None)
+ sync_rid = ctx.bwd_info.get("sync_rid", None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
@@ -98,7 +96,7 @@ def backward(ctx, grad_output):
else:
wait_region.move_param_to_cuda()
- h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -114,7 +112,7 @@ def backward(ctx, grad_output):
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
- '''
+ """
Convert Upload and Offload operations into runtime actions.
Argument:
@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be uploaded or freed during the forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during the backward pass.
- '''
+ """
with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
- '''
+ """
Convert Prefetch and Offload operations into runtime actions.
Argument:
@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be prefetched, waited on, or freed during the forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited on during the backward pass.
- '''
+ """
with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
# forward upload
fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
- fwd_info['h2d_rid'] = region.r_id
+ fwd_info["h2d_rid"] = region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- fwd_info['d2h_rid'] = r_idx - 1
+ fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
+ bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_upload_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
+ )
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
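The insertion above is a standard `torch.fx` rewrite: create a `call_function` node after an anchor node and rewire the anchor's users to read from the new node (the role played by `replace_node_users`). A minimal sketch with an illustrative identity hook:

import torch
from torch import fx

def hook(x, fwd_info):
    return x  # identity at runtime; fwd_info would drive uploads/offloads

class M(torch.nn.Module):

    def forward(self, x):
        return x.relu() + 1

gm = fx.symbolic_trace(M())
anchor = next(n for n in gm.graph.nodes if n.op == "placeholder")
with gm.graph.inserting_after(anchor):
    new_node = gm.graph.create_node("call_function", hook, args=(anchor, {"h2d_rid": 0}))
anchor.replace_all_uses_with(new_node)
new_node.args = (anchor, {"h2d_rid": 0})  # re-point the hook at the anchor itself
gm.recompile()
print(gm(torch.ones(2)))  # tensor([2., 2.])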
@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
- upload_apply_node = mod_graph.create_node('call_function',
- convert_fwd_upload_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, {}))
+ upload_apply_node = mod_graph.create_node(
+ "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
+ )
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# forward prefetch
fwd_info = {}
if region.param_size:
- fwd_info['sync_rid'] = region.r_id
+ fwd_info["sync_rid"] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
- fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
+ fwd_info["h2d_rid"] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- fwd_info['d2h_rid'] = r_idx - 1
+ fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- bwd_info['sync_rid'] = r_idx - 1
+ bwd_info["sync_rid"] = r_idx - 1
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
- bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
+ bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_prefetch_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function",
+ convert_fwd_prefetch_bwd_offload_to_action,
+ args=(last_inp_node, fwd_info, bwd_info),
+ )
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region:
- bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
+ bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_prefetch_bwd_offload_to_action,
- args=(last_inp_node, {}, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
+ )
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
return gm
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
index 161f7ff86898..a6b4904f2617 100644
--- a/colossalai/auto_parallel/offload/solver.py
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -1,6 +1,6 @@
import time
-from typing import List, Dict, Type
from abc import ABC, abstractmethod
+from typing import Dict, List, Type
NOT_NVML = False
try:
@@ -10,10 +10,11 @@
import torch
from torch.fx.node import Node
+
from colossalai.utils.cuda import get_current_device
-from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
+from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
from .util import NodeInfo, NvDevicePower
@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget, owing to errors in the estimation of peak memory and execution time.
"""
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0,
- error_factor: float = 0.95) -> None:
-
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
self.region_list = region_list
self.error_factor: float = error_factor
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
- self.memory_budget = torch.cuda.get_device_properties(
- get_current_device()).total_memory * self.error_factor
+ self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()
@@ -94,7 +90,7 @@ def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: floa
if extra_cost == 0:
# means data transfer overhead can be completely overlapped
- return (float('inf'), total_mem_saving, peak_mem_saving)
+ return (float("inf"), total_mem_saving, peak_mem_saving)
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
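The profit above is a tuple ordered by (memory saving per unit of extra transfer cost, total saving, peak saving); a fully overlapped transfer maps to infinite first-component profit. A quick worked example, assuming the comparison is lexicographic like plain Python tuples:

def offload_profit(total_mem_saving, peak_mem_saving, extra_cost):
    if extra_cost == 0:
        # data transfer overhead can be completely overlapped
        return (float("inf"), total_mem_saving, peak_mem_saving)
    return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)

overlapped = offload_profit(64e6, 32e6, 0.0)
exposed = offload_profit(64e6, 32e6, 1.5e-3)  # pays 1.5 ms of transfer time
print(overlapped > exposed)  # True: tuples compare element by element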
@@ -122,9 +118,7 @@ def _update_state(self, best_ts: TrainingSimulator):
self.best_ts = best_ts
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
- def _update_node_mem_info(self,
- fwd_mem_info: Dict[Node, float],
- bwd_mem_info: Dict[Node, float]):
+ def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
"""
Update the runtime memory information of the node.
@@ -134,12 +128,10 @@ def _update_node_mem_info(self,
"""
for node, mem in fwd_mem_info.items():
- assert hasattr(node, 'node_info') and isinstance(
- node.node_info, NodeInfo)
+ assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_fwd_mem = mem
for node, mem in bwd_mem_info.items():
- assert hasattr(node, 'node_info') and isinstance(
- node.node_info, NodeInfo)
+ assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_bwd_mem = mem
def _extract_computing_power(self):
@@ -159,12 +151,12 @@ def _extract_computing_power(self):
return NvDevicePower.RTX3080_FP16 * units
elif device_name.__contains__("RTX 3090"):
return NvDevicePower.RTX3090_FP16 * units
- elif device_name.__contains__('V100'):
+ elif device_name.__contains__("V100"):
return NvDevicePower.V100_FP16 * units
elif device_name.__contains__("A100"):
return NvDevicePower.A100_FP16 * units
else:
- raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
+ raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")
def _profile_bandwidth(self):
"""
@@ -172,9 +164,9 @@ def _profile_bandwidth(self):
using data volumes ranging from 1KB to 1GB.
"""
- print('profiling bandwidth ......')
+ print("profiling bandwidth ......")
link_to_bandwidth = {}
- links = ['h2d', 'd2h']
+ links = ["h2d", "d2h"]
for link in links:
t_size = 1024
@@ -182,24 +174,22 @@ def _profile_bandwidth(self):
# from 1KB to 1GB
for i in range(21):
- if link == 'h2d':
- src_tensor = torch.ones(
- int(t_size), dtype=torch.int8, pin_memory=True)
- dst_tensor = torch.ones(
- (int(t_size)), dtype=torch.int8, device='cuda')
- elif link == 'd2h':
- src_tensor = torch.ones(
- int(t_size), dtype=torch.int8, device='cuda')
- dst_tensor = torch.ones(
- (int(t_size)), dtype=torch.int8, pin_memory=True)
+ if link == "h2d":
+ src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
+ dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
+ elif link == "d2h":
+ src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
+ dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)
def func():
dst_tensor.copy_(src_tensor)
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
- print(f'size: {t_size / 1024 ** 2:.3f} MB, '
- f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
- f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
+ print(
+ f"size: {t_size / 1024 ** 2:.3f} MB, "
+ f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
+ f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
+ )
t_size *= 2
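The profiling loop doubles the transfer size each step and derives bandwidth as bytes per measured second. A host-only stand-in that runs without CUDA (the timer below is a hypothetical replacement for `benchmark_func`, and only the first few sizes are covered):

import time

import torch

def benchmark(func, number=5, repeat=3):
    best = float("inf")
    for _ in range(repeat):
        t0 = time.perf_counter()
        for _ in range(number):
            func()
        best = min(best, (time.perf_counter() - t0) / number)
    return best

t_size = 1024
for _ in range(5):  # 1 KB .. 16 KB; the real loop goes up to 1 GB
    src_tensor = torch.ones(int(t_size), dtype=torch.int8)
    dst_tensor = torch.empty(int(t_size), dtype=torch.int8)
    bandwidth = t_size / benchmark(lambda: dst_tensor.copy_(src_tensor))
    print(f"size: {t_size / 1024 ** 2:.3f} MB, bandwidth: {bandwidth / 1024 ** 3:.3f} GB/s")
    t_size *= 2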
@@ -208,10 +198,7 @@ def func():
class SynGreedySolver(Solver):
-
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0) -> None:
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
super().__init__(region_list, memory_budget)
self.best_ts: SynTrainingSimulator = None
@@ -258,7 +245,8 @@ def _call_solver(self):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
- f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
+ )
def _call_solver_l2l(self):
"""
@@ -270,7 +258,6 @@ def _call_solver_l2l(self):
region.is_syn = True
def _try_to_offload(self, offload_region: Region):
-
# record previous information
orig_need_offload = offload_region.need_offload
assert not orig_need_offload
@@ -297,23 +284,17 @@ def _eval_one_choice(self, offload_region: Region):
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
- extra_comm_cost = 2.0 * \
- ts._get_communication_overhead('h2d', offload_region.param_size)
+ extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
# the shared region needs to be moved twice
if offload_region.r_id < offload_region.shared_rid:
extra_comm_cost *= 2.0
- profit = self._compute_offload_profit(
- ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+ profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class AsynGreedySolver(Solver):
-
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0,
- search_window_size: int = 3):
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
super().__init__(region_list, memory_budget)
self.search_window_size = search_window_size
@@ -331,7 +312,7 @@ def _init_state(self):
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
self._update_state(ts)
- print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
+ print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")
def _call_solver(self):
"""
@@ -358,18 +339,17 @@ def _call_solver(self):
best_pref_ts = None
# search for when to prefetch the offloaded region
- for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
+ for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
if host_region.bwd_prefetch_region is not None:
continue
- temp_ts, profit = self._try_to_offload(
- host_region, region)
+ temp_ts, profit = self._try_to_offload(host_region, region)
if self._compare_profit(profit, max_prefetch_profit):
region_to_region_map[region.r_id] = host_region
max_prefetch_profit = profit
best_pref_ts = temp_ts
- if profit[0] == float('inf'):
+ if profit[0] == float("inf"):
break
if self._compare_profit(max_prefetch_profit, max_offload_profit):
@@ -392,7 +372,8 @@ def _call_solver(self):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
- f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
+ )
region_to_region_map.clear()
@@ -452,7 +433,6 @@ def _repair_strategy(self):
peak_mem_saving = 0
while len(self.region_to_region_map) and peak_mem_saving <= 0:
-
max_profit = (0,)
best_ts = None
undo_host_region = None
@@ -464,8 +444,7 @@ def _repair_strategy(self):
assert offload_region.need_offload
assert not offload_region.is_syn
- ts, profit = self._try_convert_to_syn_upload(host_region,
- offload_region)
+ ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)
if self._compare_profit(profit, max_profit):
undo_host_region = host_region
@@ -474,7 +453,7 @@ def _repair_strategy(self):
best_ts = ts
if best_ts is None:
- raise NotImplementedError('repair error!')
+ raise NotImplementedError("repair error!")
assert not undo_offload_region.is_syn
undo_offload_region.is_syn = True
@@ -500,17 +479,13 @@ def _eval_one_choice(self):
ts.execute()
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
- profit = self._compute_offload_profit(
- ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+ profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class SolverFactory:
- solvers: Dict[str, Type[Solver]] = {
- 'syn': SynGreedySolver,
- 'asyn': AsynGreedySolver
- }
+ solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}
@staticmethod
def create(solver_name: str) -> Type[Solver]:
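The body of `create` is cut off by the hunk, but the registry pattern above is simple to complete; a self-contained sketch with stub solver classes (the real `create` may validate its input differently):

from typing import Dict, Type

class Solver: ...

class SynGreedySolver(Solver): ...

class AsynGreedySolver(Solver): ...

class SolverFactory:
    solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}

    @staticmethod
    def create(solver_name: str) -> Type[Solver]:
        if solver_name not in SolverFactory.solvers:
            raise TypeError(f"Unknown solver name {solver_name}!")
        return SolverFactory.solvers[solver_name]

print(SolverFactory.create("asyn").__name__)  # AsynGreedySolver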
diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py
index de58023ec2d6..728d8daf9a46 100644
--- a/colossalai/auto_parallel/offload/training_simulator.py
+++ b/colossalai/auto_parallel/offload/training_simulator.py
@@ -1,7 +1,7 @@
import bisect
-from typing import List, Dict
-from collections import OrderedDict
from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Dict, List
from torch.fx.node import Node
@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
"""
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
self.region_list = region_list
self.region_num = len(region_list)
@@ -87,11 +84,7 @@ def _get_computing_overhead(self, flop: float) -> float:
class SynTrainingSimulator(TrainingSimulator):
-
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
def execute(self):
@@ -115,8 +108,7 @@ def _eval_fwd_mem_per_region(self, region: Region):
self.runtime_mem += region.param_size
for node in region.nodes:
- self.runtime_mem += calculate_fwd_tmp(node) + \
- calculate_fwd_out(node)
+ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.fwd_node_mem[node] = self.runtime_mem
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
@@ -141,18 +133,15 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
-
self.runtime_mem -= calculate_fwd_out(node)
- self.runtime_mem += node.meta['bwd_mem_tmp'] + \
- node.meta['bwd_mem_out']
+ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
self.bwd_node_mem[node] = self.runtime_mem
- self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
- calculate_fwd_tmp(node))
+ self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -160,12 +149,14 @@ def _eval_bwd_mem_per_region(self, region: Region):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
- self.runtime_mem -= user_node.meta['bwd_mem_out']
+ self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
- raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
- f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
- f"runtime memory computed less than 0, which is miscalculated!")
+ raise ValueError(
+ f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+ f"runtime memory computed less than 0, which is miscalculated!"
+ )
# release parameter and offload gradient in region
if region.r_id == region.shared_rid:
@@ -177,23 +168,16 @@ def _eval_bwd_mem_per_region(self, region: Region):
class AsynTrainingSimulator(TrainingSimulator):
-
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
self.iter_end_time: int = 0
# the last computation execution period
- self.last_comp: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last parameter prefetch execution period
- self.last_h2d: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last gradient offload execution period
- self.last_d2h: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the forward computation execution period of the region
self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the forward parameter prefetch execution period of the region
@@ -204,10 +188,8 @@ def __init__(self,
self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the gradient offload execution period of the region
# which is divided into those that are waiting and those that have been released
- self.bwd_reg_to_offl_waiting: OrderedDict[int,
- ExecutionPeriod] = OrderedDict()
- self.bwd_reg_to_offl_freed: OrderedDict[int,
- ExecutionPeriod] = OrderedDict()
+ self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+ self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the region buffer, which records regions that are offloaded but not released
self.reg_buffer_to_free: List[int] = []
@@ -217,10 +199,8 @@ def __init__(self,
# the region execution flow,
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# when the execution reaches the i-th region.
- self.fwd_reg_flow = torch.zeros(
- (self.region_num, self.region_num)).bool()
- self.bwd_reg_flow = torch.zeros(
- (self.region_num, self.region_num)).bool()
+ self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
+ self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
def execute(self):
"""
@@ -232,7 +212,7 @@ def execute(self):
for reg in self.region_list:
if reg.param_size and reg.r_id < self.region_num - 1:
- for nr in self.region_list[reg.r_id + 1:]:
+ for nr in self.region_list[reg.r_id + 1 :]:
if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
reg.fwd_prefetch_region = nr
break
@@ -249,8 +229,7 @@ def execute(self):
self.runtime_mem -= self.region_list[reg_id].param_size
self.bwd_reg_to_offl_waiting.clear()
- self.iter_end_time = max(
- self.last_comp.end_time, self.last_d2h.end_time)
+ self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)
def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
"""
@@ -258,10 +237,8 @@ def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
"""
pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
- pref_end_time = pref_start_time + \
- 2.0 * self._get_communication_overhead('h2d', region.param_size)
- pref_ep = ExecutionPeriod(
- start_time=pref_start_time, end_time=pref_end_time)
+ pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size)
+ pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
if is_fwd:
self.fwd_reg_to_pref[region.r_id] = pref_ep
else:
@@ -276,18 +253,16 @@ def _insert_comp_exec(self, region: Region, is_fwd: bool = True):
if is_fwd:
reg_to_comp = self.fwd_reg_to_comp
reg_to_pref = self.fwd_reg_to_pref
- flop_key = 'fwd_flop'
+ flop_key = "fwd_flop"
else:
reg_to_comp = self.bwd_reg_to_comp
reg_to_pref = self.bwd_reg_to_pref
- flop_key = 'bwd_flop'
- comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
- region.r_id, ExecutionPeriod(0, 0)).end_time)
- comp_end_time = comp_start_time + \
- sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
- for node in region.nodes])
- comp_ep = ExecutionPeriod(
- start_time=comp_start_time, end_time=comp_end_time)
+ flop_key = "bwd_flop"
+ comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
+ comp_end_time = comp_start_time + sum(
+ [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
+ )
+ comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
reg_to_comp[region.r_id] = comp_ep
self.last_comp = comp_ep
@@ -297,10 +272,8 @@ def _insert_d2h_exec(self, region: Region):
"""
offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
- offl_end_time = offl_start_time + \
- self._get_communication_overhead('d2h', region.param_size)
- offl_ep = ExecutionPeriod(
- start_time=offl_start_time, end_time=offl_end_time)
+ offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size)
+ offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
self.last_d2h = offl_ep
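Both insertion rules (here and in `_insert_h2d_exec` above) schedule a transfer to start only once its link and the compute stream are both free. A worked example with made-up times; the factor of 2.0 on the prefetch mirrors the H2D rule above:

from dataclasses import dataclass

@dataclass
class ExecutionPeriod:
    start_time: float
    end_time: float

last_h2d = ExecutionPeriod(0.0, 2.0)   # H2D link busy until t=2.0
last_comp = ExecutionPeriod(0.0, 3.5)  # compute busy until t=3.5
transfer_cost = 1.25                   # one-way H2D cost for this region

pref_start_time = max(last_h2d.end_time, last_comp.end_time)  # 3.5
pref_ep = ExecutionPeriod(pref_start_time, pref_start_time + 2.0 * transfer_cost)
print(pref_ep)  # ExecutionPeriod(start_time=3.5, end_time=6.0)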
@@ -332,20 +305,17 @@ def _eval_fwd_mem_per_region(self, region: Region):
self.fwd_reg_flow[region.r_id, region.r_id] = True
else:
self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
- self.fwd_reg_flow[region.r_id,
- self.reg_buffer_to_free] = False
+ self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
self.reg_buffer_to_free.clear()
# prefetch parameters of the next region
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
self.runtime_mem += fwd_prefetch_region.param_size
- self.fwd_reg_flow[region.r_id,
- fwd_prefetch_region.r_id] = True
+ self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True
for node in region.nodes:
- self.runtime_mem += calculate_fwd_tmp(node) + \
- calculate_fwd_out(node)
+ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
@@ -354,8 +324,7 @@ def _eval_fwd_mem_per_region(self, region: Region):
if region.need_offload:
self.runtime_mem -= region.param_size
- assert len(
- self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
+ assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
self.reg_buffer_to_free.append(region.r_id)
def _eval_bwd_cost_per_region(self, region: Region):
@@ -398,8 +367,7 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
else:
self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
- self.bwd_reg_flow[region.r_id,
- self.reg_buffer_to_free] = False
+ self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
# free gradients in the buffer
while len(self.reg_buffer_to_free):
@@ -415,8 +383,7 @@ def _eval_bwd_mem_per_region(self, region: Region):
bwd_prefetch_region = region.bwd_prefetch_region
if bwd_prefetch_region:
self.runtime_mem += bwd_prefetch_region.param_size
- self.bwd_reg_flow[region.r_id,
- bwd_prefetch_region.r_id] = True
+ self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True
# add the gradient of the parameter
if region.r_id < region.shared_rid:
@@ -426,10 +393,8 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
-
self.runtime_mem -= calculate_fwd_out(node)
- self.runtime_mem += node.meta['bwd_mem_tmp'] + \
- node.meta['bwd_mem_out']
+ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
@@ -437,8 +402,7 @@ def _eval_bwd_mem_per_region(self, region: Region):
self.bwd_node_mem[node] = self.runtime_mem
- self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
- calculate_fwd_tmp(node))
+ self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -446,12 +410,14 @@ def _eval_bwd_mem_per_region(self, region: Region):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
- self.runtime_mem -= user_node.meta['bwd_mem_out']
+ self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
- raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
- f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
- f"runtime memory computed less than 0, which is miscalculated!")
+ raise ValueError(
+ f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+ f"runtime memory computed less than 0, which is miscalculated!"
+ )
# release parameters of the region
if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py
index 6b010512cc9c..cb65da79c5a2 100644
--- a/colossalai/auto_parallel/offload/util.py
+++ b/colossalai/auto_parallel/offload/util.py
@@ -35,7 +35,6 @@ class NvDevicePower:
class GlobalRuntimeInfo(metaclass=SingletonMeta):
-
def __init__(self):
self.h2d_stream = torch.cuda.Stream()
self.d2h_stream = torch.cuda.Stream()
@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
# forward
for region in region_list:
for node in region.nodes:
- runtime_mem = runtime_mem + \
- calculate_fwd_tmp(node) + calculate_fwd_out(node)
+ runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
act_peak_mem = max(runtime_mem, act_peak_mem)
# backward
bwd_deps = {}
for region in region_list.__reversed__():
for node in region.nodes.__reversed__():
runtime_mem -= calculate_fwd_out(node)
- runtime_mem = runtime_mem + \
- node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
+ runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
act_peak_mem = max(runtime_mem, act_peak_mem)
- runtime_mem = runtime_mem - \
- node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
+ runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)
# free bwd_mem_out
bwd_deps[node] = len(node.all_input_nodes)
@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
if user_node in bwd_deps:
bwd_deps[user_node] -= 1
if bwd_deps[user_node] <= 0:
- runtime_mem -= user_node.meta['bwd_mem_out']
+ runtime_mem -= user_node.meta["bwd_mem_out"]
return act_peak_mem
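The backward sweep frees a node's output gradient only after its last consumer has run, tracked with a dependency counter. The same bookkeeping in miniature (node names, sizes, and visit order are illustrative):

bwd_deps = {"a": 2, "b": 1}       # "a" has two consumers, "b" has one
bwd_mem_out = {"a": 100, "b": 50}
runtime_mem = sum(bwd_mem_out.values())

for producer in ["a", "b", "a"]:  # the reversed traversal touches a, b, then a again
    bwd_deps[producer] -= 1
    if bwd_deps[producer] <= 0:
        runtime_mem -= bwd_mem_out[producer]

print(runtime_mem)  # 0: each gradient is released exactly once, after its last use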
@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
def requires_upload_p_in_fwd(shared_reg: Region):
- return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
- and shared_reg.need_offload)
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
+ )
def requires_release_p_in_bwd(shared_reg: Region):
- return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
- and shared_reg.need_offload)
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
+ )
def requires_offload_g_in_bwd(region: Region):
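The two predicates above encode the shared-region rule: a region whose r_id is not smaller than its twin's keeps responsibility for the parameters, while the earlier twin only uploads or releases them when it is itself offloaded. A small truth check with a stub Region (field names follow the code; the values are made up):

from dataclasses import dataclass

@dataclass
class Reg:
    r_id: int
    shared_rid: int
    need_offload: bool

def requires_upload_p_in_fwd(shared_reg):
    return (shared_reg.r_id >= shared_reg.shared_rid) or (
        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
    )

print(requires_upload_p_in_fwd(Reg(3, 3, False)))  # True: r_id == shared_rid (not shared)
print(requires_upload_p_in_fwd(Reg(1, 7, False)))  # False: earlier twin, kept resident
print(requires_upload_p_in_fwd(Reg(1, 7, True)))   # True: earlier twin, but offloaded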
diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
index ffda58e0689f..ba290ee839d8 100644
--- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py
+++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
@@ -14,18 +14,20 @@
shape_consistency_manager = ShapeConsistencyManager()
-def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
- target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
+def _construct_shard_meta_info(
+ node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
+) -> ShardMetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
- origin_sharding_spec, target_sharding_spec)
+ origin_sharding_spec, target_sharding_spec
+ )
meta_info = ShardMetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is counted in numbers of elements (numel)
# get mem cost for ShardMetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract an input node that has _meta_data and get its element size
- input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+ input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
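Multiplying by `element_size()` converts those element counts into bytes. For instance:

import torch

t = torch.empty(1024, dtype=torch.float16)
numel_cost = t.numel()                     # cost counted in elements: 1024
byte_cost = numel_cost * t.element_size()  # fp16 is 2 bytes per element: 2048
print(byte_cost)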
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
meta_info.memory_cost = mem_cost
# get computation cost for ShardMetaInfo
- meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
- total_cost['backward'] * element_length,
- total_cost['total'] * element_length)
+ meta_info.compute_cost = TrainCycleItem(
+ total_cost["forward"] * element_length,
+ total_cost["backward"] * element_length,
+ total_cost["total"] * element_length,
+ )
# get tensor shape for ShardMetaInfo
origin_sharding_spec: ShardingSpec
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
- meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+ meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
- meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+ meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
return meta_info
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
- origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
- user_node_index]
+ origin_sharding_spec, target_sharding_spec = (
+ origin_spec_dict[node_index],
+ sharding_spec_dict[node_index][user_node_index],
+ )
return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
# this case is for all_reduce; there is no memory cost
meta_info = ShardMetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
- output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+ output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
- meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
- total_cost['backward'] * element_length,
- total_cost['total'] * element_length)
+ meta_info.compute_cost = TrainCycleItem(
+ total_cost["forward"] * element_length,
+ total_cost["backward"] * element_length,
+ total_cost["total"] * element_length,
+ )
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
- meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+ meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
- meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+ meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
else:
# this case will be handled by shape consistency manager
- origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
- 'tgt_spec']
+ origin_sharding_spec, target_sharding_spec = (
+ comm_action.comm_spec["src_spec"],
+ comm_action.comm_spec["tgt_spec"],
+ )
meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
-def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
- comm_actions_dict: Dict) -> GraphModule:
+def comm_metainfo_pass(
+ gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
+) -> GraphModule:
"""
The method manages all the metainfo of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
- setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+ setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
- setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+ setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm
diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py
index 0673b767de7b..9b000549de6c 100644
--- a/colossalai/auto_parallel/passes/meta_info_prop.py
+++ b/colossalai/auto_parallel/passes/meta_info_prop.py
@@ -21,16 +21,15 @@ def _normalize_tuple(x):
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
-
def __init__(self, module: GraphModule) -> None:
self.module = module
self.func_dict = {
- 'placeholder': self.placeholder_handler,
- 'get_attr': self.get_attr_handler,
- 'output': self.output_handler,
- 'call_function': self.node_handler,
- 'call_module': self.node_handler,
- 'call_method': self.node_handler,
+ "placeholder": self.placeholder_handler,
+ "get_attr": self.get_attr_handler,
+ "output": self.output_handler,
+ "call_function": self.node_handler,
+ "call_module": self.node_handler,
+ "call_method": self.node_handler,
}
def _set_data_ptr(self, x):
@@ -46,7 +45,7 @@ def _is_inplace(self, node: Node):
"""
Check if the node is an inplace operation.
"""
- if node.op == 'call_module':
+ if node.op == "call_module":
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
@@ -66,7 +65,7 @@ def placeholder_handler(self, node: Node) -> None:
Handle the placeholder node.
"""
graph_info = GraphInfo()
- out = _normalize_tuple(getattr(node, '_meta_data', None))
+ out = _normalize_tuple(getattr(node, "_meta_data", None))
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@@ -96,7 +95,7 @@ def node_handler(self, node: Node) -> None:
"""
Handle other kinds of nodes.
"""
- assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
+ assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_strategy_info
meta_info: ShardMetaInfo
@@ -126,7 +125,8 @@ def node_handler(self, node: Node) -> None:
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
- (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+ (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
+ )
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr
diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index 2049a06187d2..27afe72c0db8 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -1,18 +1,10 @@
-from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- OperationData,
- OperationDataType,
- TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
-def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
- user_node_index: int):
+def runtime_apply_for_iterable_object(
+ node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
+):
"""
This method will be invoked at runtime to perform shape consistency, making sure that activations of type tuple or list
are converted into the form the user node expects.
"""
rst = []
- for index, (origin_sharding_spec,
- target_sharding_spec) in enumerate(zip(origin_dict[node_index],
- input_dict[node_index][user_node_index])):
+ for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
+ zip(origin_dict[node_index], input_dict[node_index][user_node_index])
+ ):
rst.append(
- shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
- target_sharding_spec))
+ shape_consistency_manager.apply_for_autoparallel_runtime(
+ node[index], origin_sharding_spec, target_sharding_spec
+ )
+ )
rst = type(node)(rst)
return rst
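The final `type(node)(rst)` rebuild keeps the container type (tuple in, tuple out). The same element-wise pattern in miniature, with a string stand-in for the shape-consistency apply:

def apply_stub(x, src_spec, tgt_spec):
    return f"{x}:{src_spec}->{tgt_spec}"

node = ("t0", "t1")  # an activation that is a tuple of tensors
origin_specs = ("R", "S0")
target_specs = ("S0", "R")

rst = [apply_stub(node[i], o, t) for i, (o, t) in enumerate(zip(origin_specs, target_specs))]
rst = type(node)(rst)
print(rst)  # ('t0:R->S0', 't1:S0->R')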
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else:
- origin_sharding_spec = comm_action.comm_spec['src_spec']
- tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
+ origin_sharding_spec = comm_action.comm_spec["src_spec"]
+ tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
node_to_index_dict = {}
index = 0
for node in nodes:
- if node.target == 'sharding_spec_convert_dict':
+ if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
- if node.target == 'origin_node_sharding_spec_dict':
+ if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
- if node.target == 'comm_actions_dict':
+ if node.target == "comm_actions_dict":
comm_actions_dict_node = node
continue
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
- if not hasattr(node, 'best_strategy') or node.op == 'output':
+ if not hasattr(node, "best_strategy") or node.op == "output":
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
- node.target_sharding_specs,
- (list,
- tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
+ node.target_sharding_specs, (list, tuple)
+ ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
total_difference = 0
- for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
- node.target_sharding_specs[user_node_index]):
+ for sharding_spec, target_sharding_spec in zip(
+ node.sharding_spec, node.target_sharding_specs[user_node_index]
+ ):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function',
- runtime_apply_for_iterable_object,
- args=(node, origin_dict_node, input_dict_node,
- node_to_index_dict[node], user_node_index))
+ shape_consistency_node = mod_graph.create_node(
+ "call_function",
+ runtime_apply_for_iterable_object,
+ args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+ )
else:
- assert isinstance(node.sharding_spec,
- ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
+ assert isinstance(
+ node.sharding_spec, ShardingSpec
+ ), "node.sharding_spec should be type of ShardingSpec, tuple or list."
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function',
- runtime_apply,
- args=(node, origin_dict_node, input_dict_node,
- node_to_index_dict[node], user_node_index))
- if hasattr(user_node.meta['info'], 'activation_checkpoint'):
- MetaInfo(shape_consistency_node,
- mod_dir=user_node.meta['info'].mod_dir,
- activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
+ shape_consistency_node = mod_graph.create_node(
+ "call_function",
+ runtime_apply,
+ args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+ )
+ if hasattr(user_node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ shape_consistency_node,
+ mod_dir=user_node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
+ )
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional or keyword argument of the user node
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
- if not hasattr(node, 'best_strategy') or node.op == 'output':
+ if not hasattr(node, "best_strategy") or node.op == "output":
continue
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
-
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
- comm_spec_apply_node = mod_graph.create_node('call_function',
- runtime_comm_spec_apply,
- args=(comm_object, comm_actions_dict_node,
- node_to_index_dict[node], op_data.name))
+ comm_spec_apply_node = mod_graph.create_node(
+ "call_function",
+ runtime_comm_spec_apply,
+ args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+ )
# the origin node may be a positional or keyword argument of the user node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
- comm_spec_apply_node = mod_graph.create_node('call_function',
- runtime_comm_spec_apply,
- args=(node, comm_actions_dict_node,
- node_to_index_dict[node], op_data.name))
+ comm_spec_apply_node = mod_graph.create_node(
+ "call_function",
+ runtime_comm_spec_apply,
+ args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+ )
user_list = list(node.users.keys())
for user in user_list:
if user == comm_spec_apply_node:
@@ -211,10 +212,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
- if hasattr(node.meta['info'], 'activation_checkpoint'):
- MetaInfo(comm_spec_apply_node,
- mod_dir=node.meta['info'].mod_dir,
- activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+ if hasattr(node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ comm_spec_apply_node,
+ mod_dir=node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+ )
return gm
@@ -227,21 +230,21 @@ def _act_annotation_pass(gm: torch.fx.GraphModule):
nodes = tuple(mod_graph.nodes)
for node in nodes:
- if not hasattr(node.meta, 'activation_checkpoint'):
- from .runtime_preparation_pass import size_processing
+ if not hasattr(node.meta, "activation_checkpoint"):
+ pass
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
- if 'activation_checkpoint' in user_node.meta:
- user_act_annotation = user_node.meta['activation_checkpoint']
+ if "activation_checkpoint" in user_node.meta:
+ user_act_annotation = user_node.meta["activation_checkpoint"]
break
for input_node in node._input_nodes.keys():
- if 'activation_checkpoint' in input_node.meta:
- input_act_annotation = input_node.meta['activation_checkpoint']
+ if "activation_checkpoint" in input_node.meta:
+ input_act_annotation = input_node.meta["activation_checkpoint"]
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
- node.meta['activation_checkpoint'] = user_act_annotation
+ node.meta["activation_checkpoint"] = user_act_annotation
return gm
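
For context, the node-insertion pattern that recurs throughout this pass reduces to a few lines of torch.fx graph surgery. The sketch below is illustrative only: the toy module and the stand-in `double` function are assumptions, but the `inserting_before`/`create_node`/rewire sequence mirrors what the pass does with `runtime_apply`.

    import torch
    import torch.fx


    def double(x: torch.Tensor) -> torch.Tensor:
        # stand-in for runtime_apply; the real pass converts sharding specs here
        return x * 2


    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)


    gm = torch.fx.symbolic_trace(Toy())
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.relu:
            # insert a converter node before the user node, then rewire the
            # user's args, mirroring mod_graph.inserting_before(user_node) above
            with gm.graph.inserting_before(node):
                converted = gm.graph.create_node("call_function", double, args=(node.args[0],))
            node.args = (converted,)
    gm.recompile()
    assert torch.equal(gm(torch.ones(2)), torch.relu(torch.ones(2) * 2))
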
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 0ed0742ee57e..65c3d8e0cbeb 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -1,19 +1,12 @@
import operator
-from copy import deepcopy
from typing import Dict, List, Union
import torch
-from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- OperationDataType,
- ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
@@ -25,11 +18,13 @@
shape_consistency_manager = ShapeConsistencyManager()
-def size_processing(size: Union[int, torch.Size],
- dim_partition_dict: Dict[int, List[int]],
- device_mesh_info: Dict[int, int],
- target_dim: int = None,
- node_name: str = None):
+def size_processing(
+ size: Union[int, torch.Size],
+ dim_partition_dict: Dict[int, List[int]],
+ device_mesh_info: Dict[int, int],
+ target_dim: int = None,
+ node_name: str = None,
+):
"""
This method is invoked at runtime to convert a size node's value according to the distributed information.
"""
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
return size
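
Conceptually, size_processing recovers the global size from a locally observed one: each tensor dimension that is partitioned over device-mesh axes is multiplied by the number of devices along those axes. A minimal sketch of that arithmetic, assuming a plain torch.Size and evenly divisible partitions:

    from typing import Dict, List

    import torch


    def recover_global_size(
        local_size: torch.Size,
        dim_partition_dict: Dict[int, List[int]],  # tensor dim -> device-mesh axes
        device_mesh_info: Dict[int, int],          # device-mesh axis -> #devices
    ) -> torch.Size:
        """Toy version of the size rescaling performed by size_processing."""
        global_size = list(local_size)
        for dim, mesh_axes in dim_partition_dict.items():
            for axis in mesh_axes:
                global_size[dim] *= device_mesh_info[axis]
        return torch.Size(global_size)


    # a (2, 8) local shard, with dim 1 split over mesh axes 0 and 1 (2 x 4 devices)
    assert recover_global_size(torch.Size([2, 8]), {1: [0, 1]}, {0: 2, 1: 4}) == torch.Size([2, 64])
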
-def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
- strategies_constructor: StrategiesConstructor):
+def solution_annotation_pass(
+ gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
+):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
@@ -70,14 +66,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
# stick the solution strategy to the corresponding node
- setattr(node, 'best_strategy', strategies_vector[strategy_index])
- setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
+ setattr(node, "best_strategy", strategies_vector[strategy_index])
+ setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
- str(node))
+ str(node)
+ )
# attach the corresponding metainfo if node has the attribute `strategies_info`
- if hasattr(node, 'strategies_info'):
- setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
+ if hasattr(node, "strategies_info"):
+ setattr(node, "best_strategy_info", node.strategies_info[strategy_index])
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
@@ -92,15 +89,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
- setattr(node, 'target_sharding_specs', target_sharding_specs)
+ setattr(node, "target_sharding_specs", target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
- if node.op == 'get_attr':
- assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+ if node.op == "get_attr":
+ assert len(target_sharding_specs) == 1, "sharing weight is not supported in the current version."
target_node = node.strategies_vector.successor_nodes[0]
node_name = str(node)
- if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+ if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
node_name = str(target_node)
target_node = target_node.strategies_vector.successor_nodes[0]
user_strategy = target_node.best_strategy
@@ -122,11 +119,11 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
# add above dicts into graph
for node in nodes:
- if node.op != 'placeholder':
+ if node.op != "placeholder":
with mod_graph.inserting_before(node):
- input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
- origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
- comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
+ input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+ origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
+ comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
break
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
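
The placeholder nodes created here only reserve argument slots in the traced graph; the actual dictionaries are supplied at call time (ModuleWrapper, later in this patch, forwards them as keyword arguments). A toy version of the same wiring, with a made-up `extra_dict` slot:

    import torch
    import torch.fx


    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1


    gm = torch.fx.symbolic_trace(Toy())
    # reserve an extra argument slot once, before the first non-placeholder node,
    # just like the sharding_spec_convert_dict placeholder above
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            with gm.graph.inserting_before(node):
                gm.graph.create_node("placeholder", target="extra_dict")
            break
    gm.recompile()
    # the recompiled forward now accepts the dict, which wrapper code
    # supplies on every call
    out = gm(torch.zeros(2), extra_dict={"any": "runtime info"})
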
@@ -148,7 +145,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
- '''
+ """
A helper function to extract the target dimension from size node.
There are two usages of torch.Tensor.size:
1. tensor.size()
@@ -156,7 +153,7 @@ def _extract_target_dim(node):
If a target_dim is assigned, the output will be of type int instead of torch.Size.
Otherwise, the output will be of type torch.Size and this function will return None.
- '''
+ """
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
@@ -165,19 +162,21 @@ def _extract_target_dim(node):
return target_dim
def _post_processing(node, size_processing_node):
- '''
+ """
This function is used to process the dependency between the size node and its users after
inserting the size_process_node.
- '''
+ """
# store original node and processing node pair in node_pairs dictionary
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
- if hasattr(node.meta['info'], 'activation_checkpoint'):
- MetaInfo(size_processing_node,
- mod_dir=node.meta['info'].mod_dir,
- activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+ if hasattr(node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ size_processing_node,
+ mod_dir=node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+ )
user_list = list(node.users.keys())
for user in user_list:
@@ -196,10 +195,10 @@ def _post_processing(node, size_processing_node):
user.kwargs = new_kwargs
def _update_slice_object_args(slice_object):
- '''
+ """
This function is used to update the slice object argument list.
If the slice object contains a Node argument, the size node will be replaced with the corresponding size_processing node.
- '''
+ """
if isinstance(slice_object, slice):
start = slice_object.start
stop = slice_object.stop
@@ -220,8 +219,7 @@ def _update_slice_object_args(slice_object):
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes:
-
- if node.op == 'call_method' and node.target == 'size':
+ if node.op == "call_method" and node.target == "size":
# extract useful information from size node
# dim_partition_dict indicates which dimensions of the
# size value should be enlarged.
@@ -232,14 +230,14 @@ def _update_slice_object_args(slice_object):
# insert size_processing node
with mod_graph.inserting_after(node):
- size_processing_node = mod_graph.create_node('call_function',
- size_processing,
- args=(node, dim_partition_dict, device_mesh_info,
- target_dim, node.name))
+ size_processing_node = mod_graph.create_node(
+ "call_function",
+ size_processing,
+ args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
+ )
_post_processing(node, size_processing_node)
- if node.op == 'call_function' and node.target == operator.getitem:
-
+ if node.op == "call_function" and node.target == operator.getitem:
getitem_index = node.args[1]
# slice object is quite special in torch.fx graph,
# On one side, we treat slice object same as type of int,
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec):
- '''
+ """
This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec.
- '''
+ """
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
- assert isinstance(sharding_spec,
- (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+ assert isinstance(
+ sharding_spec, (tuple, list)
+ ), "sharding_spec should be type of ShardingSpec, tuple, list or None"
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
@@ -322,8 +321,9 @@ def _process_node_arguments(node):
else:
new_args.append(arg)
else:
- assert isinstance(arg,
- (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+ assert isinstance(
+ arg, (int, tuple, list)
+ ), "The argument in view node should be either type of Node or int."
if isinstance(arg, (tuple, list)):
new_args.extend(arg)
else:
@@ -332,7 +332,7 @@ def _process_node_arguments(node):
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node)
- if node.op == 'call_method':
+ if node.op == "call_method":
args_to_process = list(new_args[1:])
else:
args_to_process = list(new_args)
@@ -350,7 +350,7 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
args_to_process = tuple(args_to_process)
- if node.op == 'call_method':
+ if node.op == "call_method":
new_args = (new_args[0],) + args_to_process
else:
new_args = args_to_process
@@ -358,9 +358,9 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
node.args = new_args
def _filter_node_with_shape_args(node):
- if node.op == 'call_method':
+ if node.op == "call_method":
target = getattr(node.args[0]._meta_data.__class__, node.target)
- elif node.op == 'call_function':
+ elif node.op == "call_function":
target = node.target
else:
target = None
@@ -371,7 +371,7 @@ def _filter_node_with_shape_args(node):
for node in nodes:
# skip the placeholder node added in _solution_annotation pass
- if not hasattr(node, 'sharding_spec'):
+ if not hasattr(node, "sharding_spec"):
continue
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
@@ -392,15 +392,21 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param, name=None):
-
comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action, name):
-
- if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
+ if (
+ node.op == "call_module"
+ and op_data.type == OperationDataType.PARAM
+ and op_data.name == name
+ and comm_action.comm_type == CommType.HOOK
+ ):
return True
- if node.op == 'get_attr' and isinstance(
- node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+ if (
+ node.op == "get_attr"
+ and isinstance(node._meta_data, torch.nn.parameter.Parameter)
+ and comm_action.comm_type == CommType.HOOK
+ ):
return True
return False
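
CommType.HOOK actions are realized through autograd gradient hooks rather than graph nodes. A self-contained sketch of the hook mechanics, with a no-op `fake_all_reduce` standing in for the real `_all_reduce` and the stream handling simplified:

    import torch


    def fake_all_reduce(grad: torch.Tensor) -> torch.Tensor:
        # stand-in for the real _all_reduce over the device-mesh process group
        return grad


    def add_grad_comm_hook(param: torch.nn.Parameter, overlap: bool, stream=None):
        def hook_fn(grad):
            if overlap and stream is not None:
                # launch the communication on a side stream so it overlaps
                # with the remaining backward computation
                with torch.cuda.stream(stream):
                    fake_all_reduce(grad)
            else:
                fake_all_reduce(grad)

        param.register_hook(hook_fn)


    p = torch.nn.Parameter(torch.randn(4))
    add_grad_comm_hook(p, overlap=False)  # overlap=True needs a CUDA stream
    (p * 2).sum().backward()              # hook_fn fires when p's grad is produced
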
@@ -410,7 +416,6 @@ def _filter_param_to_hook(node, op_data, comm_action, name):
if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap):
-
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
@@ -426,22 +431,26 @@ def _shard_param(param, target_sharding_spec):
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
- setattr(param, 'sharding_spec', origin_sharding_spec)
+ setattr(param, "sharding_spec", origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
- shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
- target_sharding_spec).detach().clone())
+ shape_consistency_manager.apply_for_autoparallel_runtime(
+ param.data, param.sharding_spec, target_sharding_spec
+ )
+ .detach()
+ .clone()
+ )
return param
for node in nodes:
- if node.op == 'call_module':
+ if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: more work is needed to take care of the shared parameters.
- if hasattr(target_module, 'processed') and target_module.processed:
+ if hasattr(target_module, "processed") and target_module.processed:
continue
- setattr(target_module, 'processed', True)
+ setattr(target_module, "processed", True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
param = _shard_param(param, target_sharding_spec)
@@ -453,7 +462,7 @@ def _shard_param(param, target_sharding_spec):
# apply the sharding spec of buffers
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
- setattr(buffer, 'sharding_spec', origin_sharding_spec)
+ setattr(buffer, "sharding_spec", origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded
@@ -461,7 +470,7 @@ def _shard_param(param, target_sharding_spec):
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
- if node.op == 'get_attr':
+ if node.op == "get_attr":
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
Replace the original kernel with a kernel that performs implicit communication inside.
"""
- pass
-def runtime_preparation_pass(gm: torch.fx.GraphModule,
- solution: List[int],
- device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor,
- overlap=False):
+def runtime_preparation_pass(
+ gm: torch.fx.GraphModule,
+ solution: List[int],
+ device_mesh: DeviceMesh,
+ strategies_constructor: StrategiesConstructor,
+ overlap=False,
+):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
- gm, solution, strategies_constructor)
+ gm, solution, strategies_constructor
+ )
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py
index 99c124934060..e9c2c8664a61 100644
--- a/colossalai/auto_parallel/tensor_shard/constants.py
+++ b/colossalai/auto_parallel/tensor_shard/constants.py
@@ -3,9 +3,22 @@
import torch
__all__ = [
- 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
- 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
- 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
+ "ELEMENTWISE_MODULE_OP",
+ "ELEMENTWISE_FUNC_OP",
+ "RESHAPE_FUNC_OP",
+ "CONV_MODULE_OP",
+ "CONV_FUNC_OP",
+ "LINEAR_MODULE_OP",
+ "LINEAR_FUNC_OP",
+ "BATCHNORM_MODULE_OP",
+ "POOL_MODULE_OP",
+ "NON_PARAM_FUNC_OP",
+ "BCAST_FUNC_OP",
+ "EMBEDDING_MODULE_OP",
+ "LAYERNORM_MODULE_OP",
+ "ELEMENTWISE_METHOD_OP",
+ "RESHAPE_METHOD_OP",
+ "INFINITY_COST",
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -18,13 +31,13 @@
torch.nn.functional.relu,
torch.nn.functional.dropout,
# softmax should not be here
- torch.nn.functional.softmax
+ torch.nn.functional.softmax,
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
# TODO: contiguous may need some extra processing.
- torch.Tensor.contiguous
+ torch.Tensor.contiguous,
]
RESHAPE_FUNC_OP = [
torch.flatten,
@@ -42,15 +55,36 @@
torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
- torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
- operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
+ torch.add,
+ torch.sub,
+ torch.mul,
+ torch.div,
+ torch.floor_divide,
+ torch.true_divide,
+ operator.add,
+ operator.sub,
+ operator.mul,
+ operator.floordiv,
+ operator.truediv,
+ torch.matmul,
+ operator.pow,
+ torch.pow,
]
CONV_MODULE_OP = [
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
- torch.nn.ConvTranspose3d
+ torch.nn.Conv1d,
+ torch.nn.Conv2d,
+ torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d,
+ torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d,
]
CONV_FUNC_OP = [
- torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
+ torch.conv1d,
+ torch.conv2d,
+ torch.conv3d,
+ torch.conv_transpose1d,
+ torch.conv_transpose2d,
+ torch.conv_transpose3d,
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
@@ -85,7 +119,7 @@
operator.floordiv,
operator.truediv,
# softmax should not be here
- torch.nn.functional.softmax
+ torch.nn.functional.softmax,
]
INFINITY_COST = 1e13
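
These constant lists feed the node-handler registry: a handler class is registered for every op in a list (see `@operator_registry.register(BCAST_FUNC_OP)` later in this patch) and dispatch happens on `node.target`. A hand-rolled sketch of that pattern, not the real `operator_registry`:

    import operator

    import torch

    BCAST_FUNC_OP = [torch.add, torch.sub, operator.add, operator.sub]

    _registry = {}


    def register(ops):
        """Map every op in `ops` to the decorated handler class."""
        def decorator(cls):
            for op in (ops if isinstance(ops, (list, tuple)) else [ops]):
                _registry[op] = cls
            return cls
        return decorator


    @register(BCAST_FUNC_OP)
    class BinaryElementwiseHandler:
        pass


    assert _registry[torch.add] is BinaryElementwiseHandler
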
diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index b406ca6fb7e0..d82f0ef53f66 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -3,7 +3,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
@@ -14,27 +13,32 @@
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module):
- '''
+ """
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
into the forward function.
- '''
-
- def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
- origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
- '''
+ """
+
+ def __init__(
+ self,
+ module: ColoGraphModule,
+ sharding_spec_dict: Dict[int, List[ShardingSpec]],
+ origin_spec_dict: Dict[int, ShardingSpec],
+ comm_actions_dict: Dict[int, Dict[str, CommAction]],
+ ):
+ """
Args:
module: the original module
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
- '''
+ """
super(ModuleWrapper, self).__init__()
self.module = module
self.sharding_spec_dict = sharding_spec_dict
@@ -42,67 +46,68 @@ def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[S
self.comm_actions_dict = comm_actions_dict
def forward(self, *args, **kwargs):
- return self.module(*args,
- sharding_spec_convert_dict=self.sharding_spec_dict,
- origin_node_sharding_spec_dict=self.origin_spec_dict,
- comm_actions_dict=self.comm_actions_dict,
- **kwargs)
+ return self.module(
+ *args,
+ sharding_spec_convert_dict=self.sharding_spec_dict,
+ origin_node_sharding_spec_dict=self.origin_spec_dict,
+ comm_actions_dict=self.comm_actions_dict,
+ **kwargs,
+ )
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
- '''
+ """
This method is used to extract the meta_args from the dataloader as instructed by the data_process_func.
- '''
+ """
# TODO: implement this function
- pass
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
- '''
+ """
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
from the alpha_beta_dict. These two values will be used to estimate the communication cost.
- '''
+ """
# TODO: implement this function
- pass
-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
- shard_option: str):
- '''
+def build_strategy_constructor(
+ graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str
+):
+ """
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
- '''
- if solver_preference == 'standard':
+ """
+ if solver_preference == "standard":
solver_preference = SolverPerference.STANDARD
- elif solver_preference == 'tp':
+ elif solver_preference == "tp":
solver_preference = SolverPerference.TP
- elif solver_preference == 'dp':
+ elif solver_preference == "dp":
solver_preference = SolverPerference.DP
else:
- raise ValueError(f'Invalid solver_preference: {solver_preference}')
+ raise ValueError(f"Invalid solver_preference: {solver_preference}")
- if dataloader_option == 'replicated':
+ if dataloader_option == "replicated":
dataloader_option = DataloaderOption.REPLICATED
- elif dataloader_option == 'distributed':
+ elif dataloader_option == "distributed":
dataloader_option = DataloaderOption.DISTRIBUTED
else:
- raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+ raise ValueError(f"Invalid dataloader_option: {dataloader_option}")
- if shard_option == 'standard':
+ if shard_option == "standard":
shard_option = ShardOption.STANDARD
- elif shard_option == 'shard':
+ elif shard_option == "shard":
shard_option = ShardOption.SHARD
- elif shard_option == 'shard_last_axis':
+ elif shard_option == "shard_last_axis":
shard_option = ShardOption.SHARD_LAST_AXIS
- elif shard_option == 'full_shard':
+ elif shard_option == "full_shard":
shard_option = ShardOption.FULL_SHARD
else:
- raise ValueError(f'Invalid shard_option: {shard_option}')
+ raise ValueError(f"Invalid shard_option: {shard_option}")
- solver_options = SolverOptions(solver_perference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option)
+ solver_options = SolverOptions(
+ solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option
+ )
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
@@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
- '''
+ """
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
- '''
+ """
# Temporarily, we use all nodes as the liveness list; the backward memory cost is counted together
# with the forward memory cost as the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
@@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution
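
The returned solution is a list of integers, one per leaf node, each indexing into that node's strategies_vector; this is exactly how `initialize_model` renders it when return_solution is set. A toy decoding loop with made-up node and strategy names:

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class FakeNode:
        name: str
        strategies_vector: List[str]  # the real code stores ShardingStrategy objects


    nodes = [FakeNode("conv1", ["S0R = S0R x RR", "RR = RR x RR"]),
             FakeNode("linear1", ["RS1 = RR x RS1", "RR = RR x RR"])]
    solution = [0, 1]  # best strategy index per node, as produced by the solver

    for node, strategy_index in zip(nodes, solution):
        print(f"{node.name} {node.strategies_vector[strategy_index]}")
    # conv1 S0R = S0R x RR
    # linear1 RR = RR x RR
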
-def transform_to_sharded_model(gm: ColoGraphModule,
- meta_args: Dict,
- solution: List[int],
- device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor,
- overlap: bool = False):
- '''
+def transform_to_sharded_model(
+ gm: ColoGraphModule,
+ meta_args: Dict,
+ solution: List[int],
+ device_mesh: DeviceMesh,
+ strategies_constructor: StrategiesConstructor,
+ overlap: bool = False,
+):
+ """
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
- '''
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
- solution,
- device_mesh,
- strategies_constructor,
- overlap=overlap)
+ """
+ gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+ gm, solution, device_mesh, strategies_constructor, overlap=overlap
+ )
gm = runtime_apply_pass(gm)
shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
gm.recompile()
@@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule,
return gm, sharding_spec_dicts
-def initialize_device_mesh(world_size: int = -1,
- physical_devices: List[int] = None,
- alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
- logical_mesh_shape: Tuple[int] = None,
- logical_mesh_id: torch.Tensor = None):
- '''
+def initialize_device_mesh(
+ world_size: int = -1,
+ physical_devices: List[int] = None,
+ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+ logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None,
+):
+ """
This method is used to initialize the device mesh.
Args:
@@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
- '''
+ """
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
@@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1,
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
- device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
- logical_mesh_id=logical_mesh_id,
- mesh_alpha=mesh_alpha,
- mesh_beta=mesh_beta,
- init_process_group=True)
+ device_mesh = DeviceMesh(
+ physical_mesh_id=physical_mesh,
+ logical_mesh_id=logical_mesh_id,
+ mesh_alpha=mesh_alpha,
+ mesh_beta=mesh_beta,
+ init_process_group=True,
+ )
return device_mesh
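
With the defaults above, a typical call only needs an initialized process group; physical devices, the logical mesh shape, and alpha/beta values are derived or profiled. A hedged usage sketch, assuming a torchrun launch with NCCL available:

    import torch.distributed as dist

    from colossalai.auto_parallel.tensor_shard.initialize import initialize_device_mesh

    # assumption: launched via torchrun with a CUDA/NCCL environment available
    dist.init_process_group(backend="nccl")
    device_mesh = initialize_device_mesh()  # world_size is read from torch.distributed
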
-def initialize_model(model: nn.Module,
- meta_args: Dict[str, torch.Tensor],
- device_mesh: DeviceMesh,
- memory_budget: float = -1.0,
- overlap: bool = False,
- solver_preference: str = 'standard',
- dataloader_option: str = 'replicated',
- shard_option: str = 'standard',
- save_solver_solution: bool = False,
- load_solver_solution: bool = False,
- solution_path: str = None,
- return_solution: bool = False):
- '''
+def initialize_model(
+ model: nn.Module,
+ meta_args: Dict[str, torch.Tensor],
+ device_mesh: DeviceMesh,
+ memory_budget: float = -1.0,
+ overlap: bool = False,
+ solver_preference: str = "standard",
+ dataloader_option: str = "replicated",
+ shard_option: str = "standard",
+ save_solver_solution: bool = False,
+ load_solver_solution: bool = False,
+ solution_path: str = None,
+ return_solution: bool = False,
+):
+ """
This method is used to initialize the sharded model, which can then be used as a normal PyTorch model.
Args:
@@ -246,7 +257,7 @@ def initialize_model(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned. The returned
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
- '''
+ """
tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
@@ -256,11 +267,13 @@ def initialize_model(model: nn.Module,
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
- strategies_constructor = build_strategy_constructor(graph,
- device_mesh,
- solver_preference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option)
+ strategies_constructor = build_strategy_constructor(
+ graph,
+ device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option,
+ )
if load_solver_solution:
solution = torch.load(solution_path)
else:
@@ -268,8 +281,9 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
- gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
- overlap)
+ gm, sharding_spec_dicts = transform_to_sharded_model(
+ gm, meta_args, solution, device_mesh, strategies_constructor, overlap
+ )
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
@@ -277,28 +291,30 @@ def initialize_model(model: nn.Module,
solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
- solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
+ solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}")
return model_to_return, solution_to_return
else:
return model_to_return
-def autoparallelize(model: nn.Module,
- meta_args: Dict[str, torch.Tensor] = None,
- data_loader: torch.utils.data.DataLoader = None,
- data_process_func: callable = None,
- alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
- logical_mesh_shape: Tuple[int] = None,
- logical_mesh_id: torch.Tensor = None,
- solver_preference: str = 'standard',
- dataloader_option: str = 'replicated',
- shard_option: str = 'standard',
- save_solver_solution: bool = False,
- load_solver_solution: bool = False,
- solver_solution_path: str = None,
- return_solution: bool = False,
- memory_budget: float = -1.0):
- '''
+def autoparallelize(
+ model: nn.Module,
+ meta_args: Dict[str, torch.Tensor] = None,
+ data_loader: torch.utils.data.DataLoader = None,
+ data_process_func: callable = None,
+ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+ logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None,
+ solver_preference: str = "standard",
+ dataloader_option: str = "replicated",
+ shard_option: str = "standard",
+ save_solver_solution: bool = False,
+ load_solver_solution: bool = False,
+ solver_solution_path: str = None,
+ return_solution: bool = False,
+ memory_budget: float = -1.0,
+):
+ """
This method is used to initialize the device mesh, extract the meta_args, and
use them to create a sharded model.
@@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned.
memory_budget(optional): the maximum CUDA memory that can be used. If the memory budget is -1.0,
the memory budget is treated as infinite.
- '''
- device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
- logical_mesh_shape=logical_mesh_shape,
- logical_mesh_id=logical_mesh_id)
+ """
+ device_mesh = initialize_device_mesh(
+ alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id
+ )
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
- rst_to_unpack = initialize_model(model,
- meta_args,
- device_mesh,
- solver_preference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option,
- save_solver_solution=save_solver_solution,
- load_solver_solution=load_solver_solution,
- solution_path=solver_solution_path,
- return_solution=return_solution,
- memory_budget=memory_budget)
+ rst_to_unpack = initialize_model(
+ model,
+ meta_args,
+ device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option,
+ save_solver_solution=save_solver_solution,
+ load_solver_solution=load_solver_solution,
+ solution_path=solver_solution_path,
+ return_solution=return_solution,
+ memory_budget=memory_budget,
+ )
if return_solution:
model, solution = rst_to_unpack
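
End to end, the intended flow is: describe the inputs with meta tensors and let autoparallelize trace, solve, and shard. A sketch with a placeholder model and shapes (the meta_args key is assumed to match the traced forward's argument name):

    import torch
    import torch.nn as nn

    from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize

    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
    meta_args = {"input": torch.rand(64, 1024, device="meta")}

    parallel_model = autoparallelize(model, meta_args=meta_args)
    # with return_solution=True the call would instead return (model, solution),
    # where solution lists "<node name> <strategy name>" per node
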
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index 9903ca54e52c..aa2e5e9c40c0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -25,11 +25,33 @@
from .where_handler import WhereHandler
__all__ = [
- 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
- 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
- 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
- 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
- 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
- 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
- 'SplitHandler'
+ "LinearFunctionHandler",
+ "LinearModuleHandler",
+ "BMMFunctionHandler",
+ "AddBMMFunctionHandler",
+ "LayerNormModuleHandler",
+ "BatchNormModuleHandler",
+ "ConvModuleHandler",
+ "ConvFunctionHandler",
+ "UnaryElementwiseHandler",
+ "DefaultReshapeHandler",
+ "PlaceholderHandler",
+ "OutputHandler",
+ "WhereHandler",
+ "NormPoolingHandler",
+ "BinaryElementwiseHandler",
+ "MatMulHandler",
+ "operator_registry",
+ "ADDMMFunctionHandler",
+ "GetItemHandler",
+ "GetattrHandler",
+ "ViewHandler",
+ "PermuteHandler",
+ "TensorConstructorHandler",
+ "EmbeddingModuleHandler",
+ "EmbeddingFunctionHandler",
+ "SumHandler",
+ "SoftmaxHandler",
+ "TransposeHandler",
+ "SplitHandler",
]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
index da0d199c5e05..47c654d6aa43 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
@@ -2,15 +2,13 @@
import torch
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
-
-from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
-__all__ = ['ADDMMFunctionHandler']
+__all__ = ["ADDMMFunctionHandler"]
@operator_registry.register(torch.addmm)
@@ -30,25 +28,26 @@ def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
return data_type
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
-
# input operand
input_data = self.node.args[1]._meta_data
- physical_input_operand = OperationData(name=str(self.node.args[1]),
- type=self._infer_op_data_type(input_data),
- data=input_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data
+ )
# other operand
other_data = self.node.args[2]._meta_data
- physical_other_operand = OperationData(name=str(self.node.args[2]),
- type=self._infer_op_data_type(other_data),
- data=other_data)
+ physical_other_operand = OperationData(
+ name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data
+ )
# bias physical shape
bias_logical_shape = self.node._meta_data.shape
bias_data = self.node.args[0]._meta_data
- physical_bias_operand = OperationData(name=str(self.node.args[0]),
- type=self._infer_op_data_type(bias_data),
- data=bias_data,
- logical_shape=bias_logical_shape)
+ physical_bias_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=self._infer_op_data_type(bias_data),
+ data=bias_data,
+ logical_shape=bias_logical_shape,
+ )
# output
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
@@ -57,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"input": physical_input_operand,
"other": physical_other_operand,
"output": physical_output,
- 'bias': physical_bias_operand
+ "bias": physical_bias_operand,
}
return mapping
@@ -66,26 +65,27 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm")
+ )
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
- bias_op_data = op_data_mapping['bias']
+ bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- bias_sharding_spec, bias_logical_shape, bias_physical_shape)
+ bias_sharding_spec, bias_logical_shape, bias_physical_shape
+ )
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=bias_op_data,
- sharding_spec=bias_sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
+ )
strategy.communication_actions[bias_op_data] = comm_action
return strategy
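
post_process undoes broadcasting here: the strategy was generated against the broadcast (logical) bias shape, so sharding on dimensions the bias does not physically have must be dropped, and every dropped sharded dimension becomes a communication action. A simplified illustration of the bookkeeping, covering only leading dims added by broadcast (the real recover_sharding_spec_for_broadcast_shape works on ShardingSpec objects and also handles size-1 dims):

    from typing import Dict, List, Tuple


    def recover_for_broadcast(
        dim_partition: Dict[int, List[int]],  # sharding on the logical (broadcast) shape
        logical_ndim: int,
        physical_ndim: int,
    ) -> Tuple[Dict[int, List[int]], List[int]]:
        """Drop sharding on dims that were created by broadcasting."""
        offset = logical_ndim - physical_ndim  # leading dims added by broadcast
        recovered, removed_dims = {}, []
        for dim, mesh_axes in dim_partition.items():
            if dim < offset:
                removed_dims.extend(mesh_axes)  # physically missing -> needs comm
            else:
                recovered[dim - offset] = mesh_axes
        return recovered, removed_dims


    # bias of shape (N,) broadcast to a (B, N) output, sharded S0 on B and S1 on N
    assert recover_for_broadcast({0: [0], 1: [1]}, 2, 1) == ({0: [1]}, [0])
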
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
index cb1bb36b7879..df4b1d6cef3f 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
@@ -2,12 +2,12 @@
import torch
-from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
-__all__ = ['BatchNormModuleHandler']
+__all__ = ["BatchNormModuleHandler"]
@operator_registry.register(torch.nn.BatchNorm1d)
@@ -27,30 +27,37 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape,
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
- physical_running_mean_operand = OperationData(name="running_mean",
- type=OperationDataType.BUFFER,
- data=self.named_buffers['running_mean'],
- logical_shape=self.named_buffers['running_mean'].shape)
+ physical_running_mean_operand = OperationData(
+ name="running_mean",
+ type=OperationDataType.BUFFER,
+ data=self.named_buffers["running_mean"],
+ logical_shape=self.named_buffers["running_mean"].shape,
+ )
- physical_running_var_operand = OperationData(name="running_var",
- type=OperationDataType.BUFFER,
- data=self.named_buffers['running_var'],
- logical_shape=self.named_buffers['running_var'].shape)
+ physical_running_var_operand = OperationData(
+ name="running_var",
+ type=OperationDataType.BUFFER,
+ data=self.named_buffers["running_var"],
+ logical_shape=self.named_buffers["running_var"].shape,
+ )
physical_num_batches_tracked_operand = OperationData(
name="num_batches_tracked",
type=OperationDataType.BUFFER,
- data=self.named_buffers['num_batches_tracked'],
- logical_shape=self.named_buffers['num_batches_tracked'].shape)
+ data=self.named_buffers["num_batches_tracked"],
+ logical_shape=self.named_buffers["num_batches_tracked"].shape,
+ )
mapping = {
"input": physical_input_operand,
@@ -58,12 +65,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"output": physical_output,
"running_mean": physical_running_mean_operand,
"running_var": physical_running_var_operand,
- "num_batches_tracked": physical_num_batches_tracked_operand
+ "num_batches_tracked": physical_num_batches_tracked_operand,
}
- if self.named_parameters['bias'] is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ if self.named_parameters["bias"] is not None:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
index db8f0b54ddee..f8c137348353 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -4,15 +4,14 @@
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
-__all__ = ['BinaryElementwiseHandler']
+__all__ = ["BinaryElementwiseHandler"]
@operator_registry.register(BCAST_FUNC_OP)
@@ -38,7 +37,7 @@ def _get_arg_value(idx):
# The meta_data of a Node-type argument may also be a non-tensor object.
if not isinstance(meta_data, torch.Tensor):
assert isinstance(meta_data, (int, float))
- meta_data = torch.Tensor([meta_data]).to('meta')
+ meta_data = torch.Tensor([meta_data]).to("meta")
non_tensor = True
else:
@@ -46,7 +45,7 @@ def _get_arg_value(idx):
# but we can deem it as meta data
# as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float))
- meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
+ meta_data = torch.Tensor([self.node.args[idx]]).to("meta")
non_tensor = True
return meta_data, non_tensor
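
Scalar operands such as the 1 in x + 1 carry no tensor meta-data, so the handler materializes a one-element meta tensor as a stand-in and records that the operand was not a real tensor. The trick in isolation:

    import torch

    scalar = 1.5  # a non-tensor argument such as the 1 in x + 1
    assert isinstance(scalar, (int, float))
    meta_data = torch.Tensor([scalar]).to("meta")  # placeholder for strategy generation
    non_tensor = True
    print(meta_data.shape, meta_data.is_meta)  # torch.Size([1]) True
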
@@ -58,24 +57,27 @@ def _get_arg_value(idx):
# and filter the non-tensor op_data in post_process.
self.non_tensor_list = []
# assert False
- input_op_data = OperationData(name=str(self.node.args[0]),
- type=_get_op_data_type(input_meta_data),
- data=input_meta_data,
- logical_shape=bcast_shape)
- other_op_data = OperationData(name=str(self.node.args[1]),
- type=_get_op_data_type(other_meta_data),
- data=other_meta_data,
- logical_shape=bcast_shape)
- output_op_data = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=bcast_shape)
+ input_op_data = OperationData(
+ name=str(self.node.args[0]),
+ type=_get_op_data_type(input_meta_data),
+ data=input_meta_data,
+ logical_shape=bcast_shape,
+ )
+ other_op_data = OperationData(
+ name=str(self.node.args[1]),
+ type=_get_op_data_type(other_meta_data),
+ data=other_meta_data,
+ logical_shape=bcast_shape,
+ )
+ output_op_data = OperationData(
+ name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape
+ )
if non_tensor_input:
self.non_tensor_list.append(input_op_data)
if non_tensor_other:
self.non_tensor_list.append(other_op_data)
- mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+ mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
@@ -100,14 +102,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- sharding_spec, logical_shape, physical_shape)
+ sharding_spec, logical_shape, physical_shape
+ )
strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=op_data,
- sharding_spec=sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec
+ )
strategy.communication_actions[op_data] = comm_action
return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
index da2b733c9f7a..5c22ac7bef11 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
@@ -2,15 +2,13 @@
import torch
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
-
-from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
-__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler']
+__all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"]
def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
@@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
node handler to reduce code redundancy.
"""
# input operand
- physical_input_operand = OperationData(name=str(node.args[input_idx]),
- type=OperationDataType.ARG,
- data=node.args[input_idx]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data
+ )
# other operand
- physical_other_operand = OperationData(name=str(node.args[other_idx]),
- type=OperationDataType.ARG,
- data=node.args[other_idx]._meta_data)
+ physical_other_operand = OperationData(
+ name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data
+ )
# output
physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data)
@@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
if bias_idx is not None:
# bias physical shape
bias_logical_shape = node._meta_data.shape
- physical_bias_operand = OperationData(name=str(node.args[bias_idx]),
- type=OperationDataType.ARG,
- data=node.args[bias_idx]._meta_data,
- logical_shape=bias_logical_shape)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(node.args[bias_idx]),
+ type=OperationDataType.ARG,
+ data=node.args[bias_idx]._meta_data,
+ logical_shape=bias_logical_shape,
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
@@ -91,20 +91,20 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
- if 'bias' in op_data_mapping:
- bias_op_data = op_data_mapping['bias']
+ if "bias" in op_data_mapping:
+ bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- bias_sharding_spec, bias_logical_shape, bias_physical_shape)
+ bias_sharding_spec, bias_logical_shape, bias_physical_shape
+ )
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=bias_op_data,
- sharding_spec=bias_sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
+ )
strategy.communication_actions[bias_op_data] = comm_action
return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
index 272b1c85630a..fd7c1f837a5a 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
@@ -3,13 +3,13 @@
import torch
import torch.nn.functional as F
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import transpose_partition_dim
-from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
-__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
+__all__ = ["ConvModuleHandler", "ConvFunctionHandler"]
@operator_registry.register(torch.nn.Conv1d)
@@ -29,25 +29,29 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
logical_shape_for_weight = list(self.named_parameters["weight"].shape)
- logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
- 1], logical_shape_for_weight[0]
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=torch.Size(logical_shape_for_weight))
+ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
+ logical_shape_for_weight[1],
+ logical_shape_for_weight[0],
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=torch.Size(logical_shape_for_weight),
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if "bias" in self.named_parameters:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
@@ -77,9 +81,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -88,26 +92,30 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
data_type = OperationDataType.ARG
logical_shape_for_weight = list(self.node.args[1]._meta_data.shape)
- logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
- 1], logical_shape_for_weight[0]
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data,
- logical_shape=torch.Size(logical_shape_for_weight))
+ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
+ logical_shape_for_weight[1],
+ logical_shape_for_weight[0],
+ )
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]),
+ type=data_type,
+ data=self.node.args[1]._meta_data,
+ logical_shape=torch.Size(logical_shape_for_weight),
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
+ if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
- physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
- type=data_type,
- data=self.node.kwargs["bias"]._meta_data)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
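
The conv handlers above plan strategies against a weight whose first two
dimensions are swapped, then transpose the shard placement back in post_process.
A minimal sketch of the swap, with assumed illustrative shapes:

    import torch

    # Conv weights are stored as (out_channels, in_channels, *kernel_dims); the
    # strategy generator expects the (in, out, *kernel_dims) layout instead.
    weight = torch.empty(8, 3, 5)    # (out_channels, in_channels, kernel_size)
    logical = list(weight.shape)
    logical[0], logical[1] = logical[1], logical[0]
    logical_shape = torch.Size(logical)
    assert logical_shape == torch.Size([3, 8, 5])
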
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
index 0c5b9f39e1fb..feb1032a6c0f 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import DefaultReshapeGenerator, StrategyGenerator
-__all__ = ['DefaultReshapeHandler']
+__all__ = ["DefaultReshapeHandler"]
@operator_registry.register(torch.flatten)
@@ -54,17 +54,15 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
input_data = self.node.args[0]._meta_data
input_logical_shape = self.infer_logical_shape(input_data)
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=data_type,
- data=input_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape
+ )
output_data = self.node._meta_data
output_logical_shape = self.infer_logical_shape(output_data)
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape
+ )
mapping = {"input": physical_input_operand, "output": physical_output}
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
index 112ee194b4ec..f29c3a0b7d5d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
@@ -12,11 +12,12 @@
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
-__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
+__all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"]
-def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
- output_name: str) -> List[ShardingStrategy]:
+def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy: ShardingStrategy, input_name: str, output_name: str
+) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation.
@@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping={0: i},
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping={0: i},
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {0: i}
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
- strategy_copy.name = f'{strategy.name}_{i}'
+ strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
- f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
+                f"Error occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
@@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping={},
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True
+ )
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {}
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
sharding_strategies.append(strategy_copy)
return sharding_strategies
@@ -125,14 +131,16 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=input_meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=input_meta_data,
+ logical_shape=input_logical_shape,
+ )
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'])
+ physical_other_operand = OperationData(
+ name="weight", type=OperationDataType.PARAM, data=self.named_parameters["weight"]
+ )
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
@@ -141,10 +149,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=output_meta_data,
+ logical_shape=output_logical_shape,
+ )
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
@@ -157,10 +167,9 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
- input_name=str(
- self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
@@ -183,10 +192,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.node.args[0]._meta_data,
+ logical_shape=input_logical_shape,
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -194,9 +205,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
else:
data_type = OperationDataType.ARG
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data)
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data
+ )
# Same as input, in F.embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
@@ -223,8 +234,7 @@ def post_process(self, strategy: ShardingStrategy):
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
- input_name=str(
- self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
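
Both embedding handlers above generate strategies on flattened 2D logical shapes
and only later, in post_process, map a shard on logical dim 0 onto one of the
leading physical dimensions. A minimal sketch with illustrative shapes:

    import torch

    tokens = torch.zeros(2, 4, dtype=torch.long)              # physical input (B, S)
    out = torch.zeros(2, 4, 16)                               # physical output (B, S, H)
    input_logical_shape = tokens.view(-1).shape               # torch.Size([8])
    output_logical_shape = out.view(-1, out.shape[-1]).shape  # torch.Size([8, 16])
    # post_process maps a shard on logical dim 0 onto each candidate physical
    # dim in turn, yielding one strategy per mapping (name suffixes _0, _1, ...)
    candidate_physical_dims = list(range(tokens.dim()))       # [0, 1]
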
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
index 53addb873d1d..dcf0a1760a2c 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
@@ -4,7 +4,7 @@
from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator
-__all__ = ['GetattrHandler']
+__all__ = ["GetattrHandler"]
class GetattrHandler(NodeHandler):
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
index 3466e9dd9940..bd342c12eda9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
@@ -8,7 +8,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
-__all__ = ['GetItemHandler']
+__all__ = ["GetItemHandler"]
@operator_registry.register(operator.getitem)
@@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
index 452381169b74..ce6b20fa1d24 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
-__all__ = ['LayerNormModuleHandler']
+__all__ = ["LayerNormModuleHandler"]
@operator_registry.register(torch.nn.LayerNorm)
@@ -25,20 +25,22 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape,
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if self.named_parameters['bias'] is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ if self.named_parameters["bias"] is not None:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index ea541e434009..4177af4eaf71 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -3,24 +3,21 @@
import torch
import torch.nn.functional as F
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_sharding_spec_validity,
- transpose_partition_dim,
- update_partition_dim,
-)
+from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
-from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
-__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
+__all__ = ["LinearModuleHandler", "LinearFunctionHandler"]
-def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
- weight_name: str) -> ShardingStrategy:
+def _update_sharding_spec_for_transposed_weight_for_linear(
+ strategy: ShardingStrategy, weight_name: str
+) -> ShardingStrategy:
"""
This function is a helper function used by both module node handler and function node handler. This function will
convert the sharding spec for the transposed weight to the correct partition spec.
@@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
- assert op_data.logical_shape[0] == op_data.data.shape[1] and \
- op_data.logical_shape[1] == op_data.data.shape[0], \
- "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
+ assert (
+ op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0]
+    ), "Expected the logical shape of the linear operator's weight to equal its transposed physical shape"
dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy
-def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
- output_name: str) -> List[ShardingStrategy]:
+def _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy: ShardingStrategy, input_name: str, output_name: str
+) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
should have the same sharding spec.
@@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping=input_dim_mapping,
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping=input_dim_mapping,
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=output_dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
- strategy_copy.name = f'{strategy.name}_{i}'
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=output_dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
+ strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
- f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
+                f"Error occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
@@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping=input_dim_mapping,
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping=input_dim_mapping,
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=output_dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=output_dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
sharding_strategies.append(strategy_copy)
return sharding_strategies
@@ -152,10 +158,13 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping,
- self.device_mesh,
- linear_projection_type='linear',
- solver_perference=self.solver_perference))
+ LinearProjectionStrategyGenerator(
+ op_data_mapping,
+ self.device_mesh,
+ linear_projection_type="linear",
+ solver_perference=self.solver_perference,
+ )
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -163,28 +172,34 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=input_meta_data,
- logical_shape=input_logical_shape)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape[::-1])
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=input_meta_data,
+ logical_shape=input_logical_shape,
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape[::-1],
+ )
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=output_meta_data,
+ logical_shape=output_logical_shape,
+ )
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if 'bias' in self.named_parameters is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+        if "bias" in self.named_parameters:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
@@ -194,14 +209,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
2. the input and output sharding specs are updated to physical shape.
"""
# switch the dimensions of the transposed weight
- strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
+ strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight")
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
- input_name=str(self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
@@ -215,7 +230,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -223,10 +239,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.node.args[0]._meta_data,
+ logical_shape=input_logical_shape,
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -234,10 +252,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
else:
data_type = OperationDataType.ARG
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data,
- logical_shape=self.node.args[1]._meta_data.shape[::-1])
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]),
+ type=data_type,
+ data=self.node.args[1]._meta_data,
+ logical_shape=self.node.args[1]._meta_data.shape[::-1],
+ )
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
@@ -249,27 +269,28 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
+ if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
- physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
- type=data_type,
- data=self.node.kwargs["bias"]._meta_data)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
# switch the dimensions of the transposed weight
- strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
- weight_name=str(self.node.args[1]))
+ strategy = _update_sharding_spec_for_transposed_weight_for_linear(
+ strategy=strategy, weight_name=str(self.node.args[1])
+ )
# create multiple sharding strategies for the inputs
# as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
- input_name=str(self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
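
The reversed logical shape used by both linear handlers reflects that nn.Linear
stores its weight as (out_features, in_features) while the projection generator
reasons about (in, out); post_process then calls transpose_partition_dim to swap
the shard placement back. A minimal sketch with assumed shapes:

    import torch

    weight = torch.empty(16, 8)           # (out_features, in_features)
    logical_shape = weight.shape[::-1]    # torch.Size([8, 16]), planning only
    # this is the invariant asserted in
    # _update_sharding_spec_for_transposed_weight_for_linear
    assert logical_shape[0] == weight.shape[1]
    assert logical_shape[1] == weight.shape[0]
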
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index fa51114a5c94..4fab5f7f05eb 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -16,7 +16,7 @@
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
@@ -37,6 +37,7 @@ class MatMulType(Enum):
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
"""
+
DOT = 0
MM = 1
MV = 2
@@ -92,26 +93,26 @@ def __init__(self) -> None:
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping)
- input_shape = mapping_copy['input']
- other_shape = mapping_copy['other']
+ input_shape = mapping_copy["input"]
+ other_shape = mapping_copy["other"]
if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards
input_shape.insert(0, 1)
- self.padded_dim_mapping['input'] = -2
- self.padded_dim_mapping['output'] = -2
+ self.padded_dim_mapping["input"] = -2
+ self.padded_dim_mapping["output"] = -2
elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards
             other_shape.append(1)
- self.padded_dim_mapping['other'] = -1
- self.padded_dim_mapping['output'] = -1
+ self.padded_dim_mapping["other"] = -1
+ self.padded_dim_mapping["output"] = -1
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
- input_op_data = op_data_mapping['input']
- other_op_data = op_data_mapping['other']
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
@@ -131,7 +132,7 @@ def _remove_padded_dim(key, strategy):
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
- assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
+ assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}"
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
@@ -142,15 +143,15 @@ def _remove_padded_dim(key, strategy):
strategy_copy = strategy.clone()
# only one of input and other will be padded
- if 'input' in self.padded_dim_mapping:
- _remove_padded_dim('input', strategy_copy)
- _remove_padded_dim('output', strategy_copy)
- elif 'other' in self.padded_dim_mapping:
- _remove_padded_dim('other', strategy_copy)
- _remove_padded_dim('output', strategy_copy)
+ if "input" in self.padded_dim_mapping:
+ _remove_padded_dim("input", strategy_copy)
+ _remove_padded_dim("output", strategy_copy)
+ elif "other" in self.padded_dim_mapping:
+ _remove_padded_dim("other", strategy_copy)
+ _remove_padded_dim("output", strategy_copy)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
pass
return strategies
@@ -167,8 +168,8 @@ def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
# get shapes
- input_shape = mapping_copy['input']
- other_shape = mapping_copy['other']
+ input_shape = mapping_copy["input"]
+ other_shape = mapping_copy["other"]
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
@@ -179,16 +180,16 @@ def apply(self, shape_mapping: Dict[str, List[int]]):
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
- self.broadcast_dim_info['input'] = input_broadcast_dim_info
- self.broadcast_dim_info['other'] = other_broadcast_dim_info
+ self.broadcast_dim_info["input"] = input_broadcast_dim_info
+ self.broadcast_dim_info["other"] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
- mapping_copy['input'] = input_shape
- mapping_copy['other'] = other_shape
+ mapping_copy["input"] = input_shape
+ mapping_copy["other"] = other_shape
return mapping_copy
@@ -216,17 +217,18 @@ def _remove_sharding_on_broadcast_dim(key, strategy):
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
- physical_shape=tensor_shape_before_broadcast)
+ physical_shape=tensor_shape_before_broadcast,
+ )
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
- _remove_sharding_on_broadcast_dim('input', strategy_copy)
- _remove_sharding_on_broadcast_dim('other', strategy_copy)
+ _remove_sharding_on_broadcast_dim("input", strategy_copy)
+ _remove_sharding_on_broadcast_dim("other", strategy_copy)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
pass
return strategies
@@ -241,20 +243,20 @@ def __init__(self) -> None:
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
- self.batch_dims_before_view = list(mapping_copy['input'][:-2])
+ self.batch_dims_before_view = list(mapping_copy["input"][:-2])
# get shapes
- input_shape = shape_mapping['input']
- other_shape = shape_mapping['other']
+ input_shape = shape_mapping["input"]
+ other_shape = shape_mapping["other"]
# view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
output_shape = input_shape[:2] + other_shape[2:]
- mapping_copy['input'] = input_shape
- mapping_copy['other'] = other_shape
- mapping_copy['output'] = output_shape
+ mapping_copy["input"] = input_shape
+ mapping_copy["other"] = other_shape
+ mapping_copy["output"] = output_shape
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
@@ -291,11 +293,11 @@ def _update_sharding_spec(key, strategy, physical_batch_dim):
# create a new strategy
strategy_copy = strategy.clone()
try:
- _update_sharding_spec('input', strategy_copy, i)
- _update_sharding_spec('other', strategy_copy, i)
- _update_sharding_spec('output', strategy_copy, i)
+ _update_sharding_spec("input", strategy_copy, i)
+ _update_sharding_spec("other", strategy_copy, i)
+ _update_sharding_spec("output", strategy_copy, i)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
continue
return strategies
@@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
3. reshape to 3 dimensions
"""
- shape_mapping = {'input': input_shape, 'other': other_shape}
+ shape_mapping = {"input": input_shape, "other": other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
- input_shape = shape_mapping.get('input', None)
- other_shape = shape_mapping.get('other', None)
- output_shape = shape_mapping.get('output', None)
+ input_shape = shape_mapping.get("input", None)
+ other_shape = shape_mapping.get("other", None)
+ output_shape = shape_mapping.get("output", None)
return input_shape, other_shape, output_shape
@@ -364,7 +366,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -372,7 +375,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
- MatMulType.BMM: self._get_logical_shape_for_bmm
+ MatMulType.BMM: self._get_logical_shape_for_bmm,
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
@@ -390,20 +393,26 @@ def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_
output_logical_shape = torch.Size(output_logical_shape)
# create op data
- input_op_data = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.input_meta_data,
- logical_shape=input_logical_shape)
- other_op_data = OperationData(name=str(self.node.args[1]),
- type=OperationDataType.ARG,
- data=self.other_meta_data,
- logical_shape=other_logical_shape)
- output_op_data = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=self.output_meta_data,
- logical_shape=output_logical_shape)
-
- mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+ input_op_data = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.input_meta_data,
+ logical_shape=input_logical_shape,
+ )
+ other_op_data = OperationData(
+ name=str(self.node.args[1]),
+ type=OperationDataType.ARG,
+ data=self.other_meta_data,
+ logical_shape=other_logical_shape,
+ )
+ output_op_data = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=self.output_meta_data,
+ logical_shape=output_logical_shape,
+ )
+
+ mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
@@ -460,9 +469,11 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
dim_partition_dict[0] = shard
# re-init the sharding spec
- input_sharding_spec.__init__(input_sharding_spec.device_mesh,
- entire_shape=input_physical_shape,
- dim_partition_dict=dim_partition_dict)
+ input_sharding_spec.__init__(
+ input_sharding_spec.device_mesh,
+ entire_shape=input_physical_shape,
+ dim_partition_dict=dim_partition_dict,
+ )
return strategy
else:
return strategy
@@ -481,7 +492,8 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
recovered_stragies.extend(output)
else:
raise TypeError(
- f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
+ f"Found unexpected output type {type(output)} from the recover method of BmmTransform"
+ )
strategies = recovered_stragies
for index, strategies in enumerate(strategies):
strategies.name = f"{strategies.name}_{index}"
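
The handler above dispatches on the MatMulType enum whose docstring is quoted
earlier in this file. A minimal sketch of that classification, as a hypothetical
helper mirroring the enum's rules rather than the handler's actual dispatch code:

    def classify_matmul(input_dim: int, other_dim: int) -> str:
        if input_dim == 1 and other_dim == 1:
            return "DOT"    # vector-vector dot product
        if input_dim == 2 and other_dim == 2:
            return "MM"     # matrix-matrix product
        if input_dim == 2 and other_dim == 1:
            return "MV"     # matrix-vector product
        if input_dim >= 3 or other_dim >= 3:
            return "BMM"    # batched: at least one operand is 3D or higher
        raise ValueError("unhandled combination of operand dimensions")

    assert classify_matmul(1, 1) == "DOT"
    assert classify_matmul(4, 3) == "BMM"
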
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index b4b7b0e794d1..d2bad39dcbb9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -8,7 +8,6 @@
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
- OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
@@ -23,21 +22,23 @@
class NodeHandler(ABC):
- '''
+ """
The NodeHandler is an abstract class used to generate every possible strategies for an operator node.
Args:
node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
- '''
-
- def __init__(self,
- node: Node,
- device_mesh: DeviceMesh,
- strategies_vector: StrategiesVector,
- shard_option: ShardOption = ShardOption.STANDARD,
- solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
+ """
+
+ def __init__(
+ self,
+ node: Node,
+ device_mesh: DeviceMesh,
+ strategies_vector: StrategiesVector,
+ shard_option: ShardOption = ShardOption.STANDARD,
+ solver_perference: SolverPerference = SolverPerference.STANDARD,
+ ) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
@@ -68,8 +69,9 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated
# in its own node handler
- assert hasattr(node, 'strategies_vector'), \
- f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
+ assert hasattr(
+ node, "strategies_vector"
+ ), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost."
prev_strategy_vector = node.strategies_vector
prev_sharding_specs = [
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
@@ -80,10 +82,10 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
resharding_costs[node] = []
def _compute_resharding_cost(
- prev_sharding_spec: Union[ShardingSpec,
- List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
- List[ShardingSpec]],
- data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
+ prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+ current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+ data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
+ ) -> TrainCycleItem:
"""
This is a helper function to compute the resharding cost for a specific strategy of a node.
"""
@@ -94,30 +96,35 @@ def _compute_resharding_cost(
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
- prev_sharding_spec, current_sharding_spec)
-
- resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
- bwd=consistency_cost["backward"] * size_per_elem_bytes,
- total=consistency_cost["total"] * size_per_elem_bytes)
+ prev_sharding_spec, current_sharding_spec
+ )
+
+ resharding_cost = TrainCycleItem(
+ fwd=consistency_cost["forward"] * size_per_elem_bytes,
+ bwd=consistency_cost["backward"] * size_per_elem_bytes,
+ total=consistency_cost["total"] * size_per_elem_bytes,
+ )
return resharding_cost
else:
# This raise is used to check if we have missed any type of data.
# It could be merged into Parameter branch, which means we won't handle
# non-tensor arguments.
- raise ValueError(f'Unsupported data type {type(data)}')
+ raise ValueError(f"Unsupported data type {type(data)}")
else:
- assert isinstance(prev_sharding_spec, (tuple, list)), \
- f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
- or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
+ assert isinstance(
+ prev_sharding_spec, (tuple, list)
+            ), f"prev_sharding_spec should be of type ShardingSpec, List[ShardingSpec], \
+ or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}"
fwd_cost = 0
bwd_cost = 0
total_cost = 0
- for index, (prev_sharding_spec_item,
- current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
- current_sharding_spec)):
- item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
- data[index])
+ for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate(
+ zip(prev_sharding_spec, current_sharding_spec)
+ ):
+ item_cost = _compute_resharding_cost(
+ prev_sharding_spec_item, current_sharding_spec_item, data[index]
+ )
fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd
total_cost += item_cost.total
@@ -138,17 +145,17 @@ def get_target_function(self) -> callable:
This function is used to get the target function for the node handler.
The target function is used to analyze the costs of strategies.
"""
- if self.node.op in ('placeholder', 'get_attr', 'output'):
+ if self.node.op in ("placeholder", "get_attr", "output"):
return None
- if self.node.op == 'call_module':
+ if self.node.op == "call_module":
target = self.node.graph.owning_module.get_submodule(self.node.target)
- elif self.node.op == 'call_function':
+ elif self.node.op == "call_function":
target = self.node.target
- elif self.node.op == 'call_method':
+ elif self.node.op == "call_method":
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
else:
- raise ValueError(f'Unsupported node type: {self.node.op}')
+ raise ValueError(f"Unsupported node type: {self.node.op}")
return target
@@ -221,7 +228,6 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
"""
Define which generators should be used by this NodeHandler object.
"""
- pass
@abstractmethod
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -244,7 +250,6 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
}
"""
- pass
class MetaInfoNodeHandler(NodeHandler):
@@ -278,19 +283,19 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
else:
logger = get_dist_logger()
- logger.warning(f'The target function {target} is not patched yet, ')
+            logger.warning(f"The target function {target} is not patched yet")
return self.strategies_vector
class ModuleHandler(NodeHandler):
-
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# set attributes to access module parameters for convenience
- assert self.node.graph.owning_module is not None, \
- f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
+ assert (
+ self.node.graph.owning_module is not None
+            ), "The graph is not associated with a module; please make sure it can be used to instantiate a GraphModule object."
module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False))
named_buffers = list(module.named_buffers(recurse=False))
@@ -333,6 +338,6 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
else:
logger = get_dist_logger()
- logger.warning(f'The target function {target} is not patched yet')
+ logger.warning(f"The target function {target} is not patched yet")
return self.strategies_vector
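
In _compute_resharding_cost above, the shape-consistency pass returns
elementwise forward/backward/total costs that are then scaled by the dtype's
byte size. A minimal sketch of that bookkeeping, with hypothetical numbers:

    import torch

    consistency_cost = {"forward": 1024.0, "backward": 1024.0, "total": 2048.0}
    size_per_elem_bytes = torch.tensor([], dtype=torch.float16).element_size()  # 2
    fwd = consistency_cost["forward"] * size_per_elem_bytes
    bwd = consistency_cost["backward"] * size_per_elem_bytes
    total = consistency_cost["total"] * size_per_elem_bytes
    assert (fwd, bwd, total) == (2048.0, 2048.0, 4096.0)
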
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
index 4e71ccba95a7..facf19560596 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
-__all__ = ['NormPoolingHandler']
+__all__ = ["NormPoolingHandler"]
@operator_registry.register(torch.nn.MaxPool1d)
@@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
index ed120a8c3d6d..89906a205e87 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
@@ -8,7 +8,7 @@
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
-__all__ = ['OutputHandler']
+__all__ = ["OutputHandler"]
class OutputHandler(NodeHandler):
@@ -16,8 +16,9 @@ class OutputHandler(NodeHandler):
A OutputHandler which deals with the sharding strategies for Output Node.
"""
- def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
- output_option: str) -> None:
+ def __init__(
+ self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str
+ ) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.output_option = output_option
@@ -35,11 +36,11 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
for index, input_node in enumerate(self.predecessor_node):
input_meta_data = input_node._meta_data
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
- name_key = f'input_{index}'
+ name_key = f"input_{index}"
mapping[name_key] = physical_inputs
output_meta_data.append(input_meta_data)
- assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
+ assert len(output_meta_data) > 0, f"Output node {self.node} has no input node."
if len(output_meta_data) == 1:
output_meta_data = output_meta_data[0]
else:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
index 91e4a5105a08..75f07168e47b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import PermuteGenerator, StrategyGenerator
-__all__ = ['PermuteHandler']
+__all__ = ["PermuteHandler"]
@operator_registry.register(torch.Tensor.permute)
@@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
permute_dims = []
- if self.node.op == 'call_method':
+ if self.node.op == "call_method":
# torch.Tensor.permute (input, *dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
permute_dims.append(arg._meta_data)
else:
- assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
+ assert isinstance(arg, int), "The argument in permute node should be either type of Node or int."
permute_dims.append(arg)
else:
# torch.permute (input, dims)
@@ -51,8 +51,8 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
permute_dims.extend(arg._meta_data)
else:
assert isinstance(
- arg,
- (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
+ arg, (tuple, list)
+ ), "The argument in permute node should be type of Node, Tuple[int] or List[int]."
permute_dims.extend(arg)
num_dims = self.node._meta_data.dim()
@@ -61,7 +61,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
if permute_dims[i] < 0:
permute_dims[i] += num_dims
- physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
+ physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -69,7 +69,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"permute_dims": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
index e4f40fc935a4..461bc2935780 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
@@ -8,7 +8,7 @@
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
-__all__ = ['PlaceholderHandler']
+__all__ = ["PlaceholderHandler"]
class PlaceholderHandler(NodeHandler):
@@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler):
A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
"""
- def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
- placeholder_option: str) -> None:
+ def __init__(
+ self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str
+ ) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.placeholder_option = placeholder_option
@@ -25,7 +26,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
+ PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
index 730a90d74cf8..f663fc9695d3 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
@@ -1,11 +1,9 @@
class Registry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
@@ -18,7 +16,7 @@ def wrapper(func):
return wrapper
def get(self, source):
- assert source in self.store, f'{source} not found in the {self.name} registry'
+ assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
@@ -26,4 +24,4 @@ def has(self, source):
return source in self.store
-operator_registry = Registry('operator')
+operator_registry = Registry("operator")
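As context for the quote-style changes above: the Registry is a small decorator-based lookup table, where node handlers register themselves against an operator (or a list of operators) and other components later fetch them by key. A self-contained sketch of the same pattern; the toy keys and handler below are illustrative, not part of this patch:

class Registry:
    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        def wrapper(func):
            # Registering a list/tuple maps several keys to one handler.
            if isinstance(source, (list, tuple)):
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        assert source in self.store, f"{source} not found in the {self.name} registry"
        return self.store[source]

operator_registry = Registry("operator")

@operator_registry.register(["permute", "transpose"])
def reshape_like_handler():
    return "handled"

assert operator_registry.get("transpose") is reshape_like_handler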
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
index 743a1f90eaaf..6e883ea64736 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
-__all__ = ['SoftmaxHandler']
+__all__ = ["SoftmaxHandler"]
@operator_registry.register(torch.nn.Softmax)
@@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
- softmax_dim = self.node.kwargs['dim']
+ softmax_dim = self.node.kwargs["dim"]
num_dims = self.node.args[0]._meta_data.dim()
# recover negative value to positive
if softmax_dim < 0:
softmax_dim += num_dims
- physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
+ physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -49,7 +49,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"softmax_dim": physical_dim_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
index 653d158b7c36..4c32529a5d5b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import SplitGenerator, StrategyGenerator
-__all__ = ['SplitHandler']
+__all__ = ["SplitHandler"]
@operator_registry.register(torch.Tensor.split)
@@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
split_dim = self.node.args[2]
else:
if self.node.kwargs:
- split_dim = self.node.kwargs['dim']
+ split_dim = self.node.kwargs["dim"]
else:
split_dim = 0
@@ -48,7 +48,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
split_dim += num_dims
split_info = (split_size, split_dim)
- physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
+ physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -56,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
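The split handler above resolves the split dimension from three possible sources: a positional argument, a dim keyword, or the default 0. A hedged sketch of that fallback chain, using plain tuples and dicts as stand-ins for the fx node's args and kwargs:

def resolve_split_dim(args, kwargs):
    # torch.Tensor.split(tensor, split_size, dim): positional dim is args[2].
    if len(args) == 3:
        return args[2]
    return kwargs["dim"] if kwargs else 0

assert resolve_split_dim(("x", 4, 1), {}) == 1            # positional dim
assert resolve_split_dim(("x", 4), {"dim": -1}) == -1     # keyword dim, normalized later
assert resolve_split_dim(("x", 4), {}) == 0               # default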
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
index db1f31521c86..1fc7f613716b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
@@ -29,11 +29,31 @@
from .where_generator import WhereGenerator
__all__ = [
- 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
- 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
- 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
- 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator',
- 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator',
- 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator',
- 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator'
+ "StrategyGenerator",
+ "DotProductStrategyGenerator",
+ "MatVecStrategyGenerator",
+ "LinearProjectionStrategyGenerator",
+ "BatchedMatMulStrategyGenerator",
+ "ConvStrategyGenerator",
+ "UnaryElementwiseGenerator",
+ "BatchNormStrategyGenerator",
+ "GetItemStrategyGenerator",
+ "TensorStrategyGenerator",
+ "TensorTupleStrategyGenerator",
+ "LayerNormGenerator",
+ "PlaceholderGenerator",
+ "OutputGenerator",
+ "WhereGenerator",
+ "NormalPoolStrategyGenerator",
+ "BinaryElementwiseStrategyGenerator",
+ "GetattrGenerator",
+ "TensorConstructorGenerator",
+ "EmbeddingStrategyGenerator",
+ "SumGenerator",
+ "SoftmaxGenerator",
+ "ViewGenerator",
+ "PermuteGenerator",
+ "TransposeGenerator",
+ "SplitGenerator",
+ "DefaultReshapeGenerator",
]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
index 416dc9c29cad..9c766b1014c8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
@@ -14,7 +14,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['BatchNormStrategyGenerator']
+__all__ = ["BatchNormStrategyGenerator"]
class BatchNormStrategyGenerator(StrategyGenerator):
@@ -30,28 +30,31 @@ class BatchNormStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- '''
+ """
In sanity check, we need make sure the input data having correct dimension size.
For BatchNorm1d, the dim of input data should be 3([N, C, L]).
For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
- '''
+ """
# TODO: a constant coefficient need to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
@@ -69,23 +72,24 @@ def update_compute_cost(self, strategy: ShardingStrategy):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output"),
- 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
- 'running_var': self._compute_size_in_bytes(strategy, "running_var"),
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
+ "running_mean": self._compute_size_in_bytes(strategy, "running_mean"),
+ "running_var": self._compute_size_in_bytes(strategy, "running_var"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum(
- [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
+ [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
+ )
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)
@@ -93,36 +97,29 @@ def update_memory_cost(self, strategy: ShardingStrategy):
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum(
- [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
+ [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
+ )
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost,
- buffer=fwd_buffer_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost,
+ parameter=fwd_parameter_cost + bwd_parameter_cost,
+ buffer=fwd_buffer_cost,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}"
dim_partition_dict_mapping = {
- "input": {
- 1: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0]
- },
- "output": {
- 1: [mesh_dim_0]
- },
- "running_mean": {
- 0: [mesh_dim_0]
- },
- "running_var": {
- 0: [mesh_dim_0]
- },
+ "input": {1: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0]},
+ "output": {1: [mesh_dim_0]},
+ "running_mean": {0: [mesh_dim_0]},
+ "running_var": {0: [mesh_dim_0]},
"num_batches_tracked": {},
}
if self.has_bias:
@@ -132,29 +129,21 @@ def split_input_channel(self, mesh_dim_0):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 1: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "output": {
- 1: [mesh_dim_0, mesh_dim_1]
- },
- "running_mean": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "running_var": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {1: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
+ "output": {1: [mesh_dim_0, mesh_dim_1]},
+ "running_mean": {0: [mesh_dim_0, mesh_dim_1]},
+ "running_var": {0: [mesh_dim_0, mesh_dim_1]},
"num_batches_tracked": {},
}
if self.has_bias:
@@ -164,13 +153,15 @@ def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x R'
+ name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
@@ -186,21 +177,19 @@ def non_split(self):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
- "output": {
- 0: [mesh_dim_0]
- },
+ "output": {0: [mesh_dim_0]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
@@ -218,27 +207,26 @@ def split_input_batch(self, mesh_dim_0):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
@@ -256,19 +244,22 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
@@ -304,20 +295,23 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0],
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
- '''
+ """
strategy_list = []
# RS = RS x S
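The memory accounting in this generator follows a pattern repeated across the file: activation sizes are split between forward and backward, parameter sizes count in both, and buffers only appear in the forward pass. An illustrative recombination with simplified dataclasses; the byte counts are made up, not taken from the patch:

from dataclasses import dataclass

@dataclass
class MemoryCost:
    activation: float = 0.0
    parameter: float = 0.0
    buffer: float = 0.0

fwd_mem_cost = MemoryCost(activation=4e6, parameter=1e5, buffer=5e4)
bwd_mem_cost = MemoryCost(activation=3e6, parameter=1e5)

# Mirrors the total cost above: activations and parameters accumulate over
# fwd + bwd, while buffers are counted once on the forward side.
total_mem_cost = MemoryCost(
    activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
    parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
    buffer=fwd_mem_cost.buffer,
)
assert total_mem_cost.activation == 7e6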
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
index d27cc046eaf3..c7da0034ec3b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
@@ -14,7 +14,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['BinaryElementwiseStrategyGenerator']
+__all__ = ["BinaryElementwiseStrategyGenerator"]
class BinaryElementwiseStrategyGenerator(StrategyGenerator):
@@ -26,36 +26,36 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- assert len(self.op_data) == 3, \
- f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
+ assert (
+ len(self.op_data) == 3
+ ), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}"
for name, op_data in self.op_data.items():
if not isinstance(op_data.data, (torch.Tensor, int, float)):
- raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')
+ raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.")
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# since elementwise ops are not compute-intensive,
# we approximate the backward compute cost
# to be twice the fwd compute cost
fwd_compute_cost = reduce(operator.mul, shape)
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# all input, output and outputs have the same shape
- shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive
# we approximate the fwd memory cost to be the output
# and the backward memory cost to be grad of input and other
- input_bytes = self._compute_size_in_bytes(strategy, 'input')
- other_bytes = self._compute_size_in_bytes(strategy, 'other')
- output_bytes = self._compute_size_in_bytes(strategy, 'output')
+ input_bytes = self._compute_size_in_bytes(strategy, "input")
+ other_bytes = self._compute_size_in_bytes(strategy, "other")
+ output_bytes = self._compute_size_in_bytes(strategy, "output")
fwd_memory_cost = MemoryCost(activation=output_bytes)
bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
@@ -66,7 +67,7 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
- dim_size = len(self.op_data['output'].logical_shape)
+ dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
@@ -86,21 +87,22 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# convert these dim partition dict to sharding strategy
for dim_partition_dict in dim_partition_list:
- dim_partition_dict_mapping = dict(input=dim_partition_dict,
- other=dim_partition_dict,
- output=dim_partition_dict)
+ dim_partition_dict_mapping = dict(
+ input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict
+ )
try:
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
# get name
- sharding_seq = sharding_spec_mapping['input'].sharding_sequence
- name = f'{sharding_seq} = {sharding_seq} {sharding_seq}'
+ sharding_seq = sharding_spec_mapping["input"].sharding_sequence
+ name = f"{sharding_seq} = {sharding_seq} {sharding_seq}"
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
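The cost model reformatted above is simple enough to state outside the framework: the forward cost of an elementwise op is the product of its sharded per-device shape, and the backward pass is approximated as twice that. A quick sketch with a hypothetical shape:

import operator
from functools import reduce

sharded_shape = (8, 128, 128)              # hypothetical per-device shape
fwd_compute_cost = reduce(operator.mul, sharded_shape)
bwd_compute_cost = fwd_compute_cost * 2    # backward approximated as 2x forward
assert fwd_compute_cost + bwd_compute_cost == 3 * 8 * 128 * 128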
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
index e605a68a326b..5208f61543bb 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
@@ -1,11 +1,9 @@
import copy
import operator
-import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
CommType,
MemoryCost,
ShardingStrategy,
@@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- '''
+ """
In sanity check, we need make sure the input data having correct dimension size.
For Conv1d, the dim of input data should be 3([N, C, L]).
For Conv2d, the dim of input data should be 4([N, C, H, W]).
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
- '''
+ """
# TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
@@ -76,14 +77,14 @@ def update_compute_cost(self, strategy: ShardingStrategy):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
@@ -100,26 +101,20 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- 1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0]},
+ "other": {1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
@@ -132,7 +127,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@@ -140,7 +136,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -148,38 +145,41 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
@@ -196,7 +196,8 @@ def split_input_batch(self, mesh_dim_0):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -204,42 +205,45 @@ def split_input_batch(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
- "other": {
- 0: [mesh_dim_1]
- },
+ "other": {0: [mesh_dim_1]},
"output": {
0: [mesh_dim_0],
},
@@ -254,7 +258,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
@@ -263,7 +268,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -271,7 +277,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
@@ -279,23 +286,27 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+ name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
@@ -322,23 +333,27 @@ def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
- name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
+ name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R"
dim_partition_dict_mapping = {
"input": {
@@ -360,17 +375,20 @@ def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
@@ -395,17 +413,20 @@ def split_weight_out_channel(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x RR'
+ name = f"RR = RR x RR"
dim_partition_dict_mapping = {
"input": {},
@@ -418,13 +439,13 @@ def non_split(self):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
dim_partition_dict_mapping = {
"input": {
@@ -447,14 +468,16 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
@@ -464,23 +487,27 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
+ name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1],
@@ -501,17 +528,20 @@ def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
"other": {
@@ -535,13 +565,16 @@ def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
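The conv cost comments above ("2D: (H * W) * N * Cout * Cin * kernel") translate directly into arithmetic on the sharded shapes. A worked example with made-up sizes, ignoring the TODO'd division by TFLOPS and assuming a same-size output for brevity:

# Hypothetical Conv2d shapes: input [N, Cin, H, W], weight [Cout, Cin, kh, kw].
N, Cin, H, W = 16, 64, 32, 32
Cout, kh, kw = 128, 3, 3

forward_compute_cost = (H * W) * N * Cout * Cin * (kh * kw)
# Bias add is elementwise over the output, so its cost is the output size.
bias_compute_cost = N * Cout * H * W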
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
index 82a04ab52e73..385a8886f231 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
@@ -1,11 +1,9 @@
import copy
import operator
-import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
CommType,
MemoryCost,
ShardingStrategy,
@@ -27,16 +25,16 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
Note: The computation cost for the embedding handler is estimated as dense computing now.
It may not be accurate.
- '''
+ """
# TODO: estimate the embedding computation cost as sparse operation
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
other_size_product = reduce(operator.mul, sharded_other_shape)
@@ -55,9 +53,9 @@ def update_compute_cost(self, strategy: ShardingStrategy):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -75,14 +73,15 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def non_split(self):
- name = f'RR = R x RR'
+ name = f"RR = R x RR"
dim_partition_dict_mapping = {
"input": {},
@@ -92,18 +91,16 @@ def non_split(self):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_input(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
@@ -118,7 +115,8 @@ def split_input(self, mesh_dim_0):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -126,17 +124,20 @@ def split_input(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
@@ -159,7 +160,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@@ -167,7 +169,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -175,22 +178,23 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
@@ -207,7 +211,8 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -215,17 +220,20 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_embedding_dim(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
@@ -245,17 +253,20 @@ def split_embedding_dim(self, mesh_dim_0):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
@@ -275,13 +286,16 @@ def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
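The strategy names built throughout these generators ("S0S1 = S0 x RS1", "RR = R x RR", and so on) use the sharding-sequence notation of the sharding spec: one symbol per tensor dimension, where R means replicated and S{i} (or S{i}{j}) means sharded along mesh axis i (or the flattened axes i and j). A comment-only cheat sheet reflecting my reading of the code, not text added by this patch:

# "RR   = R x RR"    -> output replicated; input and embedding weight replicated
# "S0R  = S0 x RR"   -> output dim 0 sharded on mesh axis 0, inherited from the input
# "S0S1 = S0 x RS1"  -> output dim 0 sharded on axis 0 (input), dim 1 on axis 1 (weight)
# "S01R = S01 x RR"  -> dim 0 sharded across the flattened 1-D mesh (axes 0 and 1)
name = "S0S1 = S0 x RS1"   # left of "=": output; right: input x weight shardings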
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
index bbeb9a639c83..cc8d5771f28e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
@@ -10,7 +10,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['GetattrGenerator']
+__all__ = ["GetattrGenerator"]
class GetattrGenerator(StrategyGenerator):
@@ -26,10 +26,10 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
@@ -47,7 +47,7 @@ def update_memory_cost(self, strategy: ShardingStrategy):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
- dim_size = len(self.op_data['output'].logical_shape)
+ dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
@@ -78,7 +78,8 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
index 0aeb2e0d4079..6f01d9cc7f8e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
@@ -1,19 +1,13 @@
import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.logging import get_dist_logger
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import FollowingStrategyGenerator
-__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
+__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"]
class GetItemStrategyGenerator(FollowingStrategyGenerator):
@@ -35,12 +29,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -58,27 +52,29 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class TensorStrategyGenerator(GetItemStrategyGenerator):
- '''
+ """
Deal with case 1 and 2.
- '''
+ """
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
- getitem_index = self.op_data['index'].data
+ getitem_index = self.op_data["index"].data
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
try:
logger = get_dist_logger()
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = copy.deepcopy(
- strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
+ strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
+ )
int_index = False
if isinstance(getitem_index, int):
@@ -120,9 +116,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
except ShardingSpecException as e:
logger.debug(e)
continue
@@ -137,9 +135,9 @@ def collate_strategies(self) -> List[ShardingStrategy]:
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
- '''
+ """
Deal with case 3.
- '''
+ """
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -158,13 +156,15 @@ def collate_strategies(self) -> List[ShardingStrategy]:
sharding_spec_mapping["input"] = sharding_spec_for_input
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
- input_sharding_info += f'{sharding_spec.sharding_sequence}, '
+ input_sharding_info += f"{sharding_spec.sharding_sequence}, "
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
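
Editor's note: the hunks above all follow the same "following strategy" pattern — each candidate sharding spec from the predecessor node is tried in turn, and specs the getitem output cannot express are skipped via ShardingSpecException. A minimal standalone sketch of that control flow (plain Python with a stand-in exception class and made-up specs, not the colossalai API):

    class ShardingSpecException(Exception):
        pass

    def build_strategy(input_spec: str) -> str:
        if input_spec == "S01":  # pretend this spec cannot be followed
            raise ShardingSpecException(f"cannot follow {input_spec}")
        return f"{input_spec} = {input_spec}"

    predecessor_specs = ["RR", "S0R", "S01"]  # made-up candidate specs
    strategies = []
    for spec in predecessor_specs:
        try:
            strategies.append(build_strategy(spec))
        except ShardingSpecException:
            continue  # skip specs this node cannot express
    assert strategies == ["RR = RR", "S0R = S0R"]
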
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
index 65b173bbf65d..e5b7e6f25d4d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
@@ -18,7 +18,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['LayerNormGenerator']
+__all__ = ["LayerNormGenerator"]
class LayerNormGenerator(StrategyGenerator):
@@ -31,21 +31,21 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
- '''
+ """
        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
        # TODO: a constant coefficient needs to be added.
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
if self.has_bias:
            # bias add is an element-wise operation, so its cost equals the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_weight_shape)
        # in the LayerNorm context, the batch dimensions are all the dimensions that do not take part in the normalization.
- input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)]
+ input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)
forward_compute_cost = input_batch_product * norm_kernel_product
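
Editor's note: the LayerNorm forward cost above is the product of the batch dimensions times the product of the normalized dimensions. A self-contained sketch of that arithmetic with made-up shapes (plain Python, not the library code):

    import operator
    from functools import reduce

    sharded_input_shape = (8, 128, 512)  # hypothetical per-device input
    sharded_weight_shape = (512,)        # the normalized dims only

    input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]
    input_batch_product = reduce(operator.mul, input_batch_shape, 1)     # 8 * 128
    norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)  # 512
    assert input_batch_product * norm_kernel_product == 524288
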
@@ -62,18 +62,18 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
@@ -90,8 +90,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
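
Editor's note: the fwd/bwd/total accumulation in this hunk is the pattern used throughout these generators. A runnable sketch with simplified stand-ins for MemoryCost and TrainCycleItem (the real classes live in colossalai.auto_parallel.tensor_shard.sharding_strategy; the byte counts are invented):

    from dataclasses import dataclass

    @dataclass
    class MemoryCost:
        activation: int = 0
        parameter: int = 0

    @dataclass
    class TrainCycleItem:
        fwd: MemoryCost
        bwd: MemoryCost
        total: MemoryCost

    fwd = MemoryCost(activation=1024, parameter=256)
    bwd = MemoryCost(activation=512, parameter=256)
    total = MemoryCost(
        activation=fwd.activation + bwd.activation,
        parameter=fwd.parameter + bwd.parameter,
    )
    memory_cost = TrainCycleItem(fwd=fwd, bwd=bwd, total=total)
    assert memory_cost.total.activation == 1536
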
@@ -120,7 +121,8 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
@@ -128,12 +130,15 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
communication_action_mapping["bias"] = bias_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -155,7 +160,7 @@ def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimensio
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x R'
+        name = "RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
@@ -168,14 +173,16 @@ def non_split(self):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
        Generate all possible strategies for a LayerNorm node and record them in the strategies_vector.
- '''
+ """
strategy_list = []
input_data_dim = len(self.op_data["input"].logical_shape)
weight_data_dim = len(self.op_data["other"].logical_shape)
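
Editor's note on the strategy names that recur below: a name such as "S0S1 = S0R x RS1" reads output = input x other, where S{k} marks a dimension sharded along mesh axis k and R marks a replicated dimension. A short check showing how the f-strings in these generators assemble such a name:

    mesh_dim_0, mesh_dim_1 = 0, 1  # illustrative mesh axes
    name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
    assert name == "S0S1 = S0R x RS1"
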
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index aa1581b99e0f..fb182afb9175 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -1,5 +1,4 @@
import operator
-from ast import arg
from functools import reduce
from typing import List
@@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- size_mapping['bias'] = bias_size
+ size_mapping["bias"] = bias_size
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
@@ -41,45 +40,47 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# compute bwd cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
- bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
+ bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + 0)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class DotProductStrategyGenerator(MatMulStrategyGenerator):
-
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
return compute_cost
@ignore_sharding_exception
def no_split(self):
- name = f'R = R dot R'
- dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
+        name = "R = R dot R"
+ dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
- name = f'R = S{mesh_dim} dot S{mesh_dim}'
+ name = f"R = S{mesh_dim} dot S{mesh_dim}"
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
@@ -87,14 +88,17 @@ def split_one_dim(self, mesh_dim):
# get communication action
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
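
Editor's note: the ALLREDUCE_FWD_IDENTITY_BWD pattern above exists because sharding the contracted dimension leaves each device with a partial sum; the forward all-reduce restores the full result and the backward pass needs no communication. A toy numeric check (NumPy, two simulated devices):

    import numpy as np

    a = np.arange(8.0)
    b = np.arange(8.0, 16.0)
    full = a @ b

    # shard the contracted dim across two "devices"
    partials = [a[:4] @ b[:4], a[4:] @ b[4:]]
    assert sum(partials) == full  # the forward all-reduce recovers the result
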
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -112,19 +116,18 @@ def collate_strategies(self) -> List[ShardingStrategy]:
class MatVecStrategyGenerator(MatMulStrategyGenerator):
-
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
return compute_cost
@ignore_sharding_exception
@@ -133,67 +136,69 @@ def no_split(self):
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
- dim_partition_dict['bias'] = {}
+ dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
- name = f'S{mesh_dim}R = S{mesh_dim}R x R'
+ name = f"S{mesh_dim}R = S{mesh_dim}R x R"
# get sharding spec
dim_partition_dict = {
- "input": {
- 0: [mesh_dim]
- },
+ "input": {0: [mesh_dim]},
"other": {},
- "output": {
- 0: [mesh_dim]
- },
+ "output": {0: [mesh_dim]},
}
if self.has_bias:
- dim_partition_dict['bias'] = {}
+ dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
communication_action_mapping = {}
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=2)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=2,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -209,12 +214,13 @@ def collate_strategies(self) -> List[ShardingStrategy]:
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-
- def __init__(self,
- operation_data_mapping,
- device_mesh,
- linear_projection_type='linear',
- solver_perference=SolverPerference.STANDARD):
+ def __init__(
+ self,
+ operation_data_mapping,
+ device_mesh,
+ linear_projection_type="linear",
+ solver_perference=SolverPerference.STANDARD,
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference
@@ -224,17 +230,17 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=bwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+            fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
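
Editor's note: as the comment above states, for C[M, N] = A[M, P] @ B[P, N] the forward cost counts M * N * P multiplications, and the backward pass roughly doubles it (one matmul each for dA and dB). A standalone sketch with hypothetical per-device shapes:

    import operator
    from functools import reduce

    sharded_input_shape = (16, 64)  # A per device: [M, P]
    sharded_other_shape = (64, 32)  # B per device: [P, N]

    dim_m = reduce(operator.mul, sharded_input_shape[:-1])  # leading dims of A
    dim_p = sharded_other_shape[0]
    dim_n = sharded_other_shape[-1]

    fwd_compute_cost = dim_m * dim_n * dim_p  # multiplications in A @ B
    bwd_compute_cost = 2 * fwd_compute_cost   # dA = dC @ B.T, dB = A.T @ dC
    assert fwd_compute_cost + bwd_compute_cost == 98304
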
def dp_strategies(self) -> List[ShardingStrategy]:
@@ -301,28 +307,21 @@ def collate_strategies(self) -> List[ShardingStrategy]:
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- -1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0]},
+ "other": {-1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], -1: [mesh_dim_1]},
}
        # linear bias has only one dimension, but addmm bias logically has the
        # same dimensions as the output.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
- raise ('Unsupported linear projection type')
+            raise ValueError("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -333,75 +332,75 @@ def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
- communication_action_mapping['input'] = input_comm_action
- communication_action_mapping['other'] = other_comm_action
+ communication_action_mapping["input"] = input_comm_action
+ communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
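
Editor's note: the is_param branches above choose how the all-reduce is attached — parameters get a gradient hook (CommType.HOOK), while activations are converted as arguments before the op runs (CommType.BEFORE, with arg_index pointing at the argument). A sketch of that dispatch with a simplified stand-in enum (the names mirror the diff, not the full library API):

    from enum import Enum, auto

    class CommType(Enum):
        BEFORE = auto()  # convert an input argument before the op runs
        AFTER = auto()   # convert the output after the op runs
        HOOK = auto()    # register a gradient hook on a parameter

    def pick_comm_type(is_param: bool) -> CommType:
        # parameters sync gradients via hooks; activations are converted inline
        return CommType.HOOK if is_param else CommType.BEFORE

    assert pick_comm_type(True) is CommType.HOOK
    assert pick_comm_type(False) is CommType.BEFORE
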
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
# get sharding spec mapping
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0], -1: [mesh_dim_1]},
+ "other": {0: [mesh_dim_1]},
"bias": {},
- "output": {
- 0: [mesh_dim_0]
- },
+ "output": {0: [mesh_dim_0]},
}
        # linear bias has only one dimension, but addmm bias logically has the
        # same dimensions as the output.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
else:
- raise ('Unsupported linear projection type')
+            raise ValueError("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -412,66 +411,64 @@ def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
- communication_action_mapping['other'] = other_comm_action
- communication_action_mapping['output'] = output_comm_action
+ communication_action_mapping["other"] = other_comm_action
+ communication_action_mapping["output"] = output_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+ name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
# get sharding specs
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_1]
- },
+ "input": {-1: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0], -1: [mesh_dim_1]},
+ "bias": {-1: [mesh_dim_1]},
+ "output": {-1: [mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
@@ -482,34 +479,34 @@ def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping["input"] = input_comm_action
- communication_action_mapping['output'] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["output"] = output_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
- name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
+ name = f"RR = RS{mesh_dim} x S{mesh_dim}R"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim]
- },
- "other": {
- 0: [mesh_dim]
- },
+ "input": {-1: [mesh_dim]},
+ "other": {0: [mesh_dim]},
"bias": {},
"output": {},
}
@@ -520,32 +517,29 @@ def recompute_split_both_contract(self, mesh_dim):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
- communication_action_mapping['output'] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["output"] = output_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
- name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
+ name = f"RS{mesh_dim} = RR x RS{mesh_dim}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
- "other": {
- -1: [mesh_dim]
- },
- "bias": {
- -1: [mesh_dim]
- },
- "output": {
- -1: [mesh_dim]
- },
+ "other": {-1: [mesh_dim]},
+ "bias": {-1: [mesh_dim]},
+ "output": {-1: [mesh_dim]},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
@@ -554,93 +548,94 @@ def split_rhs_space_only(self, mesh_dim):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
- communication_action_mapping['input'] = input_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["input"] = input_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
}
        # linear bias has only one dimension, but addmm bias logically has the
        # same dimensions as the output.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
else:
- raise ('Unsupported linear projection type')
+            raise ValueError("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
+ name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {-1: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
"output": {},
}
@@ -652,32 +647,29 @@ def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.AFTER)
- communication_action_mapping['output'] = output_comm_action
+ comm_type=CommType.AFTER,
+ )
+ communication_action_mapping["output"] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
- "other": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
+ "other": {-1: [mesh_dim_0, mesh_dim_1]},
+ "bias": {-1: [mesh_dim_0, mesh_dim_1]},
+ "output": {-1: [mesh_dim_0, mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
@@ -687,20 +679,23 @@ def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
# get communication action
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['input'] = input_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["input"] = input_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x RR'
+        name = "RR = RR x RR"
# get sharding spec
dim_partition_dict_mapping = {
@@ -717,22 +712,24 @@ def non_split(self):
# get communication action
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
        # make sure the other operand has 2 dims
- input_data = self.op_data['input']
- other_data = self.op_data['other']
+ input_data = self.op_data["input"]
+ other_data = self.op_data["other"]
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
if self.has_bias:
- bias_data = self.op_data['bias']
+ bias_data = self.op_data["bias"]
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
@@ -757,37 +754,38 @@ def __init__(self, *args, **kwargs):
def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
# remove partition dict for dim 0
- dim_partition_dict['output'].pop(0, None)
+ dim_partition_dict["output"].pop(0, None)
        # decrease the remaining dim indices by 1
temp_dim_partition = {}
- keys = list(dim_partition_dict['output'].keys())
+ keys = list(dim_partition_dict["output"].keys())
for key in keys:
- val = dim_partition_dict['output'].pop(key)
+ val = dim_partition_dict["output"].pop(key)
temp_dim_partition[key - 1] = val
- dim_partition_dict['output'].update(temp_dim_partition)
+ dim_partition_dict["output"].update(temp_dim_partition)
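
Editor's note: a quick standalone check of the dim shift performed by _pop_batch_dim_sharding_for_output — drop the entry for dim 0, then move every remaining entry down one index (the mapping values are invented):

    dim_partition = {"output": {0: [0], 1: [1], 2: [0, 1]}}

    dim_partition["output"].pop(0, None)
    dim_partition["output"] = {
        dim - 1: axes for dim, axes in dim_partition["output"].items()
    }
    assert dim_partition["output"] == {0: [1], 1: [0, 1]}
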
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
- if 'bias' in self.op_data:
- bias_op_data = self.op_data['bias']
+ if "bias" in self.op_data:
+ bias_op_data = self.op_data["bias"]
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
- self.op_data['output'].data.shape)
+ fwd_compute_cost = self.op_data["input"].data.shape[-1] * reduce(
+ operator.mul, self.op_data["output"].data.shape
+ )
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
@ignore_sharding_exception
def split_one_batch_dim(self, mesh_dim):
- name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+ name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}"
# get sharding_spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
@@ -799,30 +797,27 @@ def split_one_batch_dim(self, mesh_dim):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- }
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -832,35 +827,28 @@ def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
+ name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0]
- },
- "bias": {
- 0: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- }
+ "input": {0: [mesh_dim_0], 1: [mesh_dim_1]},
+ "other": {0: [mesh_dim_0]},
+ "bias": {0: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -869,46 +857,40 @@ def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
# for addbmm case, other is the third argument instead of second.
- communication_action_mapping['other'].arg_index += 1
+ communication_action_mapping["other"].arg_index += 1
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- },
- "bias": {
- 1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- }
+ "input": {0: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0], 2: [mesh_dim_1]},
+ "bias": {1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 2: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -917,43 +899,41 @@ def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['input'] = input_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["input"] = input_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.BEFORE)
- communication_action_mapping['bias'] = bias_comm_action
+ comm_type=CommType.BEFORE,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
        # for addbmm case, input is the second argument instead of first.
- communication_action_mapping['input'].arg_index += 1
+ communication_action_mapping["input"].arg_index += 1
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0], 2: [mesh_dim_1]},
+ "other": {0: [mesh_dim_0], 1: [mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0],
- }
+ },
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -962,24 +942,28 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
- communication_action_mapping['output'] = output_comm_action
+ comm_type=CommType.AFTER,
+ )
+ communication_action_mapping["output"] = output_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
-
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
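
Editor's note on the batched-matmul names above: Sb marks the batch dimension, Si/Sj the lhs/rhs space dimensions, and Sk the contracted dimension, so "Sb0R = Sb0Sk1 x Sb0Sk1" splits the batch over mesh axis 0 and the contracted dim over mesh axis 1 (which again forces a forward all-reduce of the output). A decoding check:

    mesh_dim_0, mesh_dim_1 = 0, 1
    name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}"
    assert name == "Sb0R = Sb0Sk1 x Sb0Sk1"
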
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
index b7db42f8f67e..b307e38b5b6d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
@@ -21,28 +21,31 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- '''
+ """
        In the sanity check, we need to make sure the input data has the correct dimension size.
        For Pool1d, the dim of the input data should be 3 ([N, C, L]).
        For Pool2d, the dim of the input data should be 4 ([N, C, H, W]).
        For Pool3d, the dim of the input data should be 5 ([N, C, H, W, D]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+        ), "We expect the dim of the input fed into a Pool op to be in the range [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
- '''
+ """
Compute the computation cost per device with this specific strategy.
        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
- '''
+ """
        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (Lout) * N * C * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
kernel_size = self.op_data["other"].data
if isinstance(kernel_size, int):
@@ -61,8 +64,8 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -88,12 +91,16 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+ name = (
+ f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+ )
communication_action_mapping = {}
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
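
Editor's note: pooling carries no weights, so validation reduces to the rank check above (3/4/5 dims for Pool1d/2d/3d). A minimal sketch of the same check as a plain function (hypothetical helper, not part of the codebase):

    def check_pool_input(shape: tuple) -> None:
        assert len(shape) in (3, 4, 5), (
            "We expect the input to a Pool op to have 3, 4, or 5 dims."
        )

    check_pool_input((8, 3, 32, 32))  # Pool2d input [N, C, H, W] passes
    try:
        check_pool_input((8, 3))      # too few dims
    except AssertionError as err:
        print(err)
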
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
index 69d1642d4f80..33fb1ac5c5be 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
@@ -12,7 +12,7 @@
from .strategy_generator import OutputStrategyGenerator
-__all__ = ['OutputGenerator']
+__all__ = ["OutputGenerator"]
class OutputGenerator(OutputStrategyGenerator):
@@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator):
OutputGenerator is a generic class to generate strategies for Output Node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_nodes: List[Node], output_option: str):
+ def __init__(
+ self,
+ operation_data_mapping: Dict[str, OperationData],
+ device_mesh: DeviceMesh,
+ predecessor_nodes: List[Node],
+ output_option: str,
+ ):
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
self.output_option = output_option
@@ -33,9 +38,9 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
fwd_mem_cost = MemoryCost(activation=0, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
@@ -65,16 +70,18 @@ def replica_strategy(self) -> List[ShardingStrategy]:
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
- dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
+ dim_partition_dict_mapping["output"] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Output'
+ name = "Replica Output"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
@@ -82,19 +89,15 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi
Generate distributed strategy for output node.
"""
    # TODO: need to take care of the case when only the first element of the output needs to be sharded.
- output_op_data = self.op_data['output']
+ output_op_data = self.op_data["output"]
if isinstance(output_op_data.data, tuple):
length = len(output_op_data.data)
dim_partition_dict_mapping = {
- "output": [{
- 0: mesh_list
- }] * length,
+ "output": [{0: mesh_list}] * length,
}
else:
dim_partition_dict_mapping = {
- "output": {
- 0: mesh_list
- },
+ "output": {0: mesh_list},
}
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
@@ -103,19 +106,21 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Distributed Output'
+ name = "Distributed Output"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
mesh_list = [0, 1]
- if self.output_option == 'replicated':
+ if self.output_option == "replicated":
strategy_list.append(self.replica_strategy())
- elif self.output_option == 'distributed':
+ elif self.output_option == "distributed":
strategy_list.append(self.distributed_strategy(mesh_list))
return strategy_list
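
Editor's note: the two output options above differ only in the dim-partition dict they build — "replicated" shards nothing, while "distributed" shards dim 0 across both mesh axes ([0, 1]). A stand-in sketch of that choice (not the library API):

    def output_partition(output_option: str, mesh_list=(0, 1)) -> dict:
        if output_option == "replicated":
            return {}                    # no dim is sharded
        if output_option == "distributed":
            return {0: list(mesh_list)}  # shard dim 0 over both mesh axes
        raise ValueError(f"unsupported output_option {output_option!r}")

    assert output_partition("replicated") == {}
    assert output_partition("distributed") == {0: [0, 1]}
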
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
index 779a7ced93bb..df0862a396d2 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
@@ -10,7 +10,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['PlaceholderGenerator']
+__all__ = ["PlaceholderGenerator"]
class PlaceholderGenerator(StrategyGenerator):
@@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator):
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- placeholder_option: str):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.placeholder_option = placeholder_option
@@ -31,10 +32,10 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
@@ -58,11 +59,13 @@ def replica_placeholder(self) -> ShardingStrategy:
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Placeholder'
+ name = "Replica Placeholder"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -71,29 +74,31 @@ def distributed_placeholder(self, mesh_list) -> ShardingStrategy:
Generate a distributed strategy for the placeholder node.
"""
dim_partition_dict_mapping = {
- "output": {
- 0: mesh_list
- },
+ "output": {0: mesh_list},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Distributed Placeholder'
+ name = "Distributed Placeholder"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
- if self.placeholder_option == 'distributed':
+ if self.placeholder_option == "distributed":
mesh_list = [0, 1]
distributed_strategy = self.distributed_placeholder(mesh_list)
strategy_list.append(distributed_strategy)
else:
- assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
+ assert (
+ self.placeholder_option == "replicated"
+ ), f"placeholder_option {self.placeholder_option} is not supported"
replicated_strategy = self.replica_placeholder()
strategy_list.append(replicated_strategy)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
index 24f75e352935..48f454553ac7 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
@@ -17,7 +17,7 @@
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
-__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
+__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"]
class ReshapeGenerator(FollowingStrategyGenerator):
@@ -33,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -56,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -77,8 +78,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- origin_shape = self.op_data['input'].data.shape
- tgt_shape = self.op_data['tgt_shape'].data
+ origin_shape = self.op_data["input"].data.shape
+ tgt_shape = self.op_data["tgt_shape"].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
@@ -86,8 +87,9 @@ def collate_strategies(self) -> List[ShardingStrategy]:
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
- dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
- reshape_mapping_dict)
+ dim_partition_dict_for_output = infer_output_dim_partition_dict(
+ dim_partition_dict_for_input, reshape_mapping_dict
+ )
else:
dim_partition_dict_for_output = {}
@@ -119,7 +121,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
# it will gather the input through gather_dim during the forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during the backward phase.
@@ -127,10 +130,10 @@ def collate_strategies(self) -> List[ShardingStrategy]:
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
- target_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=source_spec.entire_shape,
- dim_partition_dict={})
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ target_spec = ShardingSpec(
+ device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
+ )
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -139,9 +142,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -159,7 +164,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- permute_dims = self.op_data['permute_dims'].data
+ permute_dims = self.op_data["permute_dims"].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
@@ -177,9 +182,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -199,7 +206,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
- transpose_dims = self.op_data['transpose_dims'].data
+ transpose_dims = self.op_data["transpose_dims"].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
@@ -221,9 +228,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -242,7 +251,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- split_size, split_dim = self.op_data['split_info'].data
+ split_size, split_dim = self.op_data["split_info"].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
@@ -271,7 +280,8 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
# it will gather the input through gather_dim during the forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during the backward phase.
@@ -282,7 +292,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -291,9 +301,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -341,16 +353,17 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
- target_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=source_spec.entire_shape,
- dim_partition_dict={})
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ target_spec = ShardingSpec(
+ device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
+ )
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -358,9 +371,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
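When a sharded dim does not survive the reshape (check_keep_sharding_status fails above), the generator gathers the input first (the GATHER_FWD_SPLIT_BWD pattern, CommType.BEFORE) and reshapes the replicated tensor. A rough standalone sketch of that fallback with plain tensors; reshape_with_gather and the duplicated-shard stand-in for all-gather are assumptions for illustration.

import torch

def reshape_with_gather(local_shard: torch.Tensor, shard_dim: int,
                        world_size: int, tgt_shape) -> torch.Tensor:
    # stand-in for an all-gather: reassemble the full tensor from shards
    full = torch.cat([local_shard] * world_size, dim=shard_dim)
    # the reshape then operates on the unsharded layout
    return full.reshape(tgt_shape)

x_local = torch.arange(8).reshape(2, 4)  # one of two shards of a (4, 4) tensor
print(reshape_with_gather(x_local, 0, 2, (16,)).shape)  # torch.Size([16])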
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
index a1ebadd043e2..d4382f9941d2 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
@@ -4,21 +4,9 @@
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_keep_sharding_status,
- detect_reshape_mapping,
- infer_output_dim_partition_dict,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-
-__all__ = ['SoftmaxGenerator']
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+__all__ = ["SoftmaxGenerator"]
class SoftmaxGenerator(FollowingStrategyGenerator):
@@ -30,11 +18,11 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
- '''
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ """
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
@@ -45,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -68,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -80,10 +69,10 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- softmax_dim = self.op_data['softmax_dim'].data
+ softmax_dim = self.op_data["softmax_dim"].data
if softmax_dim in dim_partition_dict_for_input:
- recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
+ dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
@@ -96,9 +85,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
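The pop above encodes the one hard constraint for softmax: the normalization dim must be whole on each device, so any sharding on it is dropped before the output spec is derived. A minimal sketch of that adjustment on plain dicts; strip_softmax_dim is a hypothetical name.

import copy

def strip_softmax_dim(dim_partition_dict, softmax_dim):
    adjusted = copy.deepcopy(dim_partition_dict)
    # a sharded softmax dim would split the normalization; gather it instead
    adjusted.pop(softmax_dim, None)
    return adjusted

print(strip_softmax_dim({0: [0], 1: [1]}, softmax_dim=1))  # {0: [0]}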
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index d42429745c61..7bf2c8cc12a3 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -39,7 +39,7 @@ def has_bias(self):
"""
A utility method to check for the existence of the bias operand for convenience.
"""
- return 'bias' in self.op_data
+ return "bias" in self.op_data
def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
@@ -49,8 +49,12 @@ def is_buffer(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER
- def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
- communication_action_mapping: Dict[str, CommSpec]):
+ def get_sharding_strategy(
+ self,
+ name: str,
+ sharding_spec_mapping: Dict[str, ShardingSpec],
+ communication_action_mapping: Dict[str, CommSpec],
+ ):
"""
A factory method to produce a ShardingStrategy object.
@@ -80,24 +84,28 @@ def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
op_data = self.op_data[op_data_name]
def _to_sharding_spec(
- data: any, logical_shape: any,
- dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
+ data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]]
+ ) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
- sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=logical_shape,
- dim_partition_dict=dim_partition_dict)
+ sharding_spec = ShardingSpec(
+ device_mesh=self.device_mesh,
+ entire_shape=logical_shape,
+ dim_partition_dict=dim_partition_dict,
+ )
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
- data, logical_shape, dim_partition_dict):
+ data, logical_shape, dim_partition_dict
+ ):
sharding_spec.append(
- _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
+ _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)
+ )
return sharding_spec
else:
return None
@@ -116,31 +124,41 @@ def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):
results[op_data] = v
return results
- def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
- logical_process_axis: Union[int, List[int]]):
+ def get_communication_spec(
+ self,
+ sharding_spec: ShardingSpec,
+ communication_pattern: CollectiveCommPattern,
+ logical_process_axis: Union[int, List[int]],
+ ):
"""
A factory method to produce a CommSpec object.
"""
- return CommSpec(comm_pattern=communication_pattern,
- sharding_spec=sharding_spec,
- logical_process_axis=logical_process_axis)
-
- def get_communication_action(self,
- sharding_spec: ShardingSpec,
- communication_pattern: CollectiveCommPattern,
- logical_process_axis: Union[int, List[int]],
- comm_type: CommType,
- arg_index: int = -1,
- key_for_kwarg: any = None) -> CommAction:
+ return CommSpec(
+ comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis
+ )
+
+ def get_communication_action(
+ self,
+ sharding_spec: ShardingSpec,
+ communication_pattern: CollectiveCommPattern,
+ logical_process_axis: Union[int, List[int]],
+ comm_type: CommType,
+ arg_index: int = -1,
+ key_for_kwarg: any = None,
+ ) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
- return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
- communication_pattern=communication_pattern,
- logical_process_axis=logical_process_axis),
- comm_type=comm_type,
- arg_index=arg_index,
- key_for_kwarg=key_for_kwarg)
+ return CommAction(
+ comm_spec=self.get_communication_spec(
+ sharding_spec=sharding_spec,
+ communication_pattern=communication_pattern,
+ logical_process_axis=logical_process_axis,
+ ),
+ comm_type=comm_type,
+ arg_index=arg_index,
+ key_for_kwarg=key_for_kwarg,
+ )
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
@@ -155,9 +173,9 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
- comm_cost.fwd += num_ele_in_comm['forward']
- comm_cost.bwd += num_ele_in_comm['backward']
- comm_cost.total += num_ele_in_comm['total']
+ comm_cost.fwd += num_ele_in_comm["forward"]
+ comm_cost.bwd += num_ele_in_comm["backward"]
+ comm_cost.total += num_ele_in_comm["total"]
# check if communication action exists
# if so, loop over each action and compute the cost of each action
@@ -169,8 +187,8 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
# this condition branch will be removed after all the handler updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
- src_spec = comm_spec['src_spec']
- tgt_spec = comm_spec['tgt_spec']
+ src_spec = comm_spec["src_spec"]
+ tgt_spec = comm_spec["tgt_spec"]
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
@@ -187,14 +205,12 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the computation flops.
"""
- pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
- pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
@@ -212,13 +228,14 @@ def _compute_size_in_bytes_helper(sharding_spec, meta_data):
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
- dtype = getattr(meta_data, 'dtype')
+ dtype = getattr(meta_data, "dtype")
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
- assert isinstance(strategy.sharding_specs[op_data], list), \
- 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
+ assert isinstance(
+ strategy.sharding_specs[op_data], list
+ ), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple."
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
@@ -270,7 +287,6 @@ def validate(self) -> bool:
Validate whether the operands are of the desired shape.
If it returns True, this generator can be used for the current operation.
"""
- pass
class FollowingStrategyGenerator(StrategyGenerator):
@@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator):
TODO: remove the original strategy_generator.py after refactoring
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_node: Node):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node
+ ):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
self.predecessor_node = predecessor_node
@@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator):
OutputStrategyGenerator is used to generate the sharding strategies for the output node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_nodes: List[Node]):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node]
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
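The _compute_size_in_bytes helper above reduces to the product of the sharded shape times the dtype's element size, with a 0-dim tensor counted as one element. A self-contained sketch of that arithmetic; size_in_bytes is a hypothetical standalone version of the same formula.

import operator
from functools import reduce

import torch

def size_in_bytes(sharded_shape, dtype) -> int:
    # an empty shape denotes a scalar tensor, i.e. one element
    num_elements = reduce(operator.mul, sharded_shape, 1)
    return num_elements * torch.tensor([], dtype=dtype).element_size()

print(size_in_bytes((2, 32), torch.float32))  # 256
print(size_in_bytes((), torch.float16))       # 2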
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
index a0fbc58d70c0..dcbf34cfd65b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
@@ -4,22 +4,9 @@
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_keep_sharding_status,
- detect_reshape_mapping,
- infer_output_dim_partition_dict,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-__all__ = ['SumGenerator']
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+__all__ = ["SumGenerator"]
class SumGenerator(FollowingStrategyGenerator):
@@ -31,24 +18,24 @@ def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
- compute_cost = TrainCycleItem(fwd=input_size_product,
- bwd=output_size_product,
- total=input_size_product + output_size_product)
+ compute_cost = TrainCycleItem(
+ fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product
+ )
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -66,8 +53,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -78,7 +66,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
+ sum_dims, sum_mapping_dict = self.op_data["sum_info"].data
# TODO: a better way to handle the distributed sum is to sum all the data on chip and then do an all-reduce
# among all the shard groups
@@ -90,7 +78,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
- raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
+ raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims")
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
@@ -105,9 +93,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
index 93cfc9eeea53..eea00c2fa064 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
@@ -1,19 +1,10 @@
-import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import StrategyGenerator
-__all__ = ['TensorConstructorGenerator']
+__all__ = ["TensorConstructorGenerator"]
class TensorConstructorGenerator(StrategyGenerator):
@@ -30,10 +21,10 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
@@ -57,11 +48,13 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Tensor Constructor'
+ name = "Replica Tensor Constructor"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
index 39799a67c5a0..943cf3f1f50d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
@@ -5,7 +5,7 @@
from .strategy_generator import FollowingStrategyGenerator
-__all__ = ['UnaryElementwiseGenerator']
+__all__ = ["UnaryElementwiseGenerator"]
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
@@ -21,12 +21,12 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -44,8 +44,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -69,9 +70,11 @@ def collate_strategies(self) -> List[ShardingStrategy]:
# we keep the same strategies under different names for node merging; this will not increase the search space,
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
index fa941f2cc51d..b27b4f3d4056 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
@@ -10,7 +10,7 @@
from .strategy_generator import StrategyGenerator
-__all__ = ['WhereGenerator']
+__all__ = ["WhereGenerator"]
class WhereGenerator(StrategyGenerator):
@@ -26,14 +26,14 @@ def update_compute_cost(self, strategy: ShardingStrategy):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'condition': self._compute_size_in_bytes(strategy, "condition"),
- 'x': self._compute_size_in_bytes(strategy, "x"),
- 'y': self._compute_size_in_bytes(strategy, "y"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "condition": self._compute_size_in_bytes(strategy, "condition"),
+ "x": self._compute_size_in_bytes(strategy, "x"),
+ "y": self._compute_size_in_bytes(strategy, "y"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -59,7 +59,7 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
"condition": dim_partition,
"x": dim_partition,
"y": dim_partition,
- "output": dim_partition
+ "output": dim_partition,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -67,9 +67,11 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
communication_action_mapping = {}
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -84,9 +86,9 @@ def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_l
return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
Generate every possible strategy for a where node, and record all strategies in the strategies_vector.
- '''
+ """
strategy_list = []
dimension_length = len(self.op_data["output"].logical_shape)
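enumerate_all_possible_output_spec above walks the placements of mesh axes mesh_dim_0 and mesh_dim_1 over the output dims; since condition, x, y, and output all share one dim_partition, each placement yields one strategy. A simplified enumeration in the same spirit; the exact candidate set the real method produces may differ.

def enumerate_partitions(mesh_dim_0, mesh_dim_1, dimension_length):
    partitions = [{}]  # fully replicated
    for dim in range(dimension_length):
        partitions.append({dim: [mesh_dim_0]})              # shard on one axis
        partitions.append({dim: [mesh_dim_1]})
        partitions.append({dim: [mesh_dim_0, mesh_dim_1]})  # shard on both axes
        for other in range(dimension_length):
            if other != dim:                                # two dims, one axis each
                partitions.append({dim: [mesh_dim_0], other: [mesh_dim_1]})
    return partitions

print(len(enumerate_partitions(0, 1, 2)))  # 9 candidate dim partitions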
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
index 86f90694e060..5b4ea0afe5f8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator
-__all__ = ['SumHandler']
+__all__ = ["SumHandler"]
@operator_registry.register(torch.Tensor.sum)
@@ -55,7 +55,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {}
- if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
+ if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]:
for i in range(num_dims):
sum_mapping_dict.update({i: i})
else:
@@ -67,7 +67,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict)
- physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)
+ physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -75,7 +75,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"sum_info": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
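The keepdim branch above determines how input dims map to output dims: with keepdim every dim maps to itself, otherwise the surviving dims are renumbered consecutively (so sum_mapping_dict[1] = 0 reads: input dim 1 becomes output dim 0). A standalone reconstruction of that logic; build_sum_mapping is a hypothetical helper name.

def build_sum_mapping(num_dims, sum_dims, keepdim):
    if keepdim:
        return {i: i for i in range(num_dims)}
    mapping, output_index = {}, 0
    for i in range(num_dims):
        if i not in sum_dims:
            mapping[i] = output_index  # input dim i survives as output dim output_index
            output_index += 1
    return mapping

print(build_sum_mapping(4, {0, 2}, keepdim=False))  # {1: 0, 3: 1}
print(build_sum_mapping(2, {1}, keepdim=True))      # {0: 0, 1: 1}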
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
index 855a2e7612af..c2aa120e8a28 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
@@ -8,7 +8,7 @@
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator
-__all__ = ['TensorConstructorHandler']
+__all__ = ["TensorConstructorHandler"]
@operator_registry.register(torch.arange)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
index 7a9d37726490..b72d9812f406 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, TransposeGenerator
-__all__ = ['TransposeHandler']
+__all__ = ["TransposeHandler"]
@operator_registry.register(torch.Tensor.transpose)
@@ -48,9 +48,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
if transpose_dims[i] < 0:
transpose_dims[i] += num_dims
- physical_shape_operand = OperationData(name='transpose_dims',
- type=OperationDataType.ARG,
- data=list(transpose_dims))
+ physical_shape_operand = OperationData(
+ name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims)
+ )
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -58,7 +58,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"transpose_dims": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
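Besides normalizing negative dims as above (d + num_dims when d < 0), the transpose strategies simply swap any partition entries sitting on the two transposed dims. A compact sketch of both steps with hypothetical helper names.

def normalize_dims(transpose_dims, num_dims):
    return [d + num_dims if d < 0 else d for d in transpose_dims]

def transpose_partition(dim_partition_dict, dim_0, dim_1):
    swapped = {}
    for dim, mesh_axes in dim_partition_dict.items():
        if dim == dim_0:
            swapped[dim_1] = mesh_axes  # the sharding follows the moved dim
        elif dim == dim_1:
            swapped[dim_0] = mesh_axes
        else:
            swapped[dim] = mesh_axes
    return swapped

print(normalize_dims([-1, 0], 4))                   # [3, 0]
print(transpose_partition({0: [0], 2: [1]}, 0, 2))  # {2: [0], 0: [1]}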
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
index 0362de780d7a..cbc873de8223 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
@@ -3,11 +3,11 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
-__all__ = ['UnaryElementwiseHandler']
+__all__ = ["UnaryElementwiseHandler"]
@operator_registry.register(torch.Tensor.to)
@@ -33,9 +33,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output}
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
index 7dff89d1d7a3..56c1d10a167e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
@@ -7,7 +7,7 @@
from .registry import operator_registry
from .strategy import StrategyGenerator, ViewGenerator
-__all__ = ['ViewHandler']
+__all__ = ["ViewHandler"]
@operator_registry.register(torch.Tensor.reshape)
@@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
- physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
+ physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -46,7 +46,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
index 6de2aaafdd01..1856a11100b0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
@@ -1,16 +1,15 @@
import copy
-import operator
from typing import Dict, List
import torch
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
-__all__ = ['WhereHandler']
+__all__ = ["WhereHandler"]
@operator_registry.register(torch.where)
@@ -28,27 +27,28 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_condition_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_x_operand = OperationData(name=str(self.node.args[1]),
- type=OperationDataType.ARG,
- data=self.node.args[1]._meta_data)
- physical_y_operand = OperationData(name=str(self.node.args[2]),
- type=OperationDataType.ARG,
- data=self.node.args[2]._meta_data)
+ physical_condition_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_x_operand = OperationData(
+ name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data
+ )
+ physical_y_operand = OperationData(
+ name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_mapping = {
"condition": physical_condition_operand,
"x": physical_x_operand,
"y": physical_y_operand,
- "output": physical_output
+ "output": physical_output,
}
logical_shape_for_all = self.node._meta_data.shape
logical_mapping = {}
for key, physical_operand in physical_mapping.items():
- logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
- logical_shape_for_all)
+ logical_mapping[key] = self.convert_physical_operand_to_logical_operand(
+ physical_operand, logical_shape_for_all
+ )
return logical_mapping, physical_mapping
@@ -64,7 +64,8 @@ def post_process(self, strategy: ShardingStrategy):
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- logical_sharding_spec, logical_shape, physical_shape)
+ logical_sharding_spec, logical_shape, physical_shape
+ )
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
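post_process above maps specs built at the broadcast (logical) shape back to each operand's physical shape; partitions that land on broadcast dims must be dropped, which is what recover_sharding_spec_for_broadcast_shape reports via removed_dims. A rough sketch of the dim bookkeeping on plain dicts; the real function works on ShardingSpec objects and also returns the removed dims.

def recover_partition(logical_partition, logical_shape, physical_shape):
    offset = len(logical_shape) - len(physical_shape)  # dims prepended by broadcasting
    physical = {}
    for dim, mesh_axes in logical_partition.items():
        phys_dim = dim - offset
        if phys_dim >= 0 and physical_shape[phys_dim] == logical_shape[dim]:
            physical[phys_dim] = mesh_axes  # the dim really exists: keep its shard
        # otherwise the dim was added or stretched by broadcasting: drop the shard
    return physical

# an operand broadcast from (8,) to (4, 8): only the sharding on dim 1 survives
print(recover_partition({0: [0], 1: [1]}, (4, 8), (8,)))  # {0: [1]}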
diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py
index f0ea502a6f0e..e87872f39c10 100644
--- a/colossalai/auto_parallel/tensor_shard/options.py
+++ b/colossalai/auto_parallel/tensor_shard/options.py
@@ -1,13 +1,14 @@
from dataclasses import dataclass
from enum import Enum
-__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
+__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"]
class SolverPerference(Enum):
"""
This enum class defines the solver preference.
"""
+
STANDARD = 0
DP = 1
TP = 2
@@ -25,6 +26,7 @@ class ShardOption(Enum):
TP_SHARD: We require the node to be sharded using tensor parallel strategies on the last device mesh axis.
TP_FULL_SHARD: We require the node to be sharded using tensor parallel strategies on all device mesh axes.
"""
+
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
@@ -35,6 +37,7 @@ class DataloaderOption(Enum):
"""
This enum class defines the dataloader option.
"""
+
REPLICATED = 0
DISTRIBUTED = 1
@@ -44,6 +47,7 @@ class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
+
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD
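With the enums and the dataclass above, configuring the search is just constructing SolverOptions. A usage sketch against the module path shown in the diff header; the print is illustrative only.

from colossalai.auto_parallel.tensor_shard.options import (
    DataloaderOption,
    ShardOption,
    SolverOptions,
    SolverPerference,
)

options = SolverOptions(
    solver_perference=SolverPerference.DP,           # prefer data-parallel strategies
    dataloader_option=DataloaderOption.DISTRIBUTED,  # shard the dataloader output
    shard_option=ShardOption.SHARD,                  # require sharded strategies
)
if options.shard_option is not ShardOption.STANDARD:
    print("replicated-only strategies will be filtered out")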
diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index 6af927272437..8e22df64d868 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -10,7 +10,6 @@
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import (
- BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
@@ -18,13 +17,14 @@
RESHAPE_METHOD_OP,
)
-__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
+__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"]
class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
+
INPUT = 0
ARG = 1
PARAM = 2
@@ -43,6 +43,7 @@ class OperationData:
data (Any): the value for this data, usually a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, which can be different from its actual shape in memory.
"""
+
name: str
type: OperationDataType
data: Any
@@ -69,13 +70,13 @@ def _infer_logical_shape(data: any):
self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str:
- return f'OperationData(name={self.name}, type={self.type})'
+ return f"OperationData(name={self.name}, type={self.type})"
def __eq__(self, other) -> bool:
return other.name == self.name
def __hash__(self) -> int:
- return hash(f'{self.name}')
+ return hash(f"{self.name}")
@dataclass
@@ -88,6 +89,7 @@ class TrainCycleItem:
fwd (float): the item for the forward pass
bwd (float): the item for the backward pass
"""
+
fwd: Any
bwd: Any
total: Any
@@ -104,6 +106,7 @@ class MemoryCost:
temp (int): the memory cost incurred by the temporary tensors in bytes.
buffer (int): the memory cost incurred by the module buffer in bytes.
"""
+
activation: int = 0
parameter: int = 0
temp: int = 0
@@ -120,6 +123,7 @@ class CommType(Enum):
HOOK: the communication action is used to do the gradient all-reduce.
IMPLICIT: the communication action happens during the kernel execution, such as in SyncBatchNorm
"""
+
BEFORE = 0
AFTER = 1
HOOK = 2
@@ -137,6 +141,7 @@ class CommAction:
arg_index: records the location of the tensor which joins the communication; we cannot use the name of the node or op_data at runtime,
because the args of the node may be changed by graph transform passes.
"""
+
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
@@ -156,6 +161,7 @@ class ShardingStrategy:
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (defaults to None)
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
"""
+
name: str
sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
compute_cost: TrainCycleItem = None
@@ -200,7 +206,6 @@ def get_sharding_spec_by_name(self, name: str):
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
def clone(self):
-
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
@@ -209,31 +214,34 @@ def _deepcopy_dict_vals(data: Dict):
# Consider the example below:
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
# In this case, if we set None on the new object, the program will crash when we try to access communication_actions.items.
- communication_actions = _deepcopy_dict_vals(
- self.communication_actions) if self.communication_actions is not None else None
+ communication_actions = (
+ _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None
+ )
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)
- return ShardingStrategy(name=self.name,
- sharding_specs=sharding_specs,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- communication_actions=communication_actions,
- resharding_costs=resharding_costs)
+ return ShardingStrategy(
+ name=self.name,
+ sharding_specs=sharding_specs,
+ compute_cost=compute_cost,
+ communication_cost=communication_cost,
+ memory_cost=memory_cost,
+ communication_actions=communication_actions,
+ resharding_costs=resharding_costs,
+ )
class StrategiesVector(list):
- '''
+ """
Each node in the fx graph has a corresponding StrategiesVector, which stores all the possible
strategies of the node.
Argument:
node (Node): the node for which the list of sharding strategies is generated.
- '''
+ """
def __init__(self, node: Node):
super().__init__()
@@ -245,7 +253,7 @@ def __init__(self, node: Node):
def check_merge(self):
merge_label = False
- if self.node.op == 'call_module':
+ if self.node.op == "call_module":
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
@@ -255,7 +263,7 @@ def check_merge(self):
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
- if self.node.op == 'call_function':
+ if self.node.op == "call_function":
# we can merge element-wise ops, because the output sharding spec is always the same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
@@ -267,7 +275,7 @@ def check_merge(self):
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
- if self.node.op == 'call_method':
+ if self.node.op == "call_method":
# we can merge reshape ops, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
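Because OperationData compares and hashes by name alone (see __eq__ and __hash__ above), a freshly built instance can look up dictionary entries keyed by an earlier instance, which is what lets cloned strategies keep indexing the same operands. A standalone illustration with a stripped-down stand-in; OpData is not the real OperationData.

class OpData:  # stand-in for OperationData: identity is the name only
    def __init__(self, name, data=None):
        self.name = name
        self.data = data
    def __eq__(self, other):
        return other.name == self.name
    def __hash__(self):
        return hash(self.name)

specs = {OpData("input", data=object()): "S0R"}
print(specs[OpData("input")])  # "S0R" -- found via a different instance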
diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
index f9e6bd923921..b930ce80a9b9 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
@@ -3,4 +3,4 @@
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
+__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"]
diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
index 1b2d3ad57407..4415d429b0c2 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
@@ -4,7 +4,7 @@
class CostGraph:
- '''
+ """
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, which stores every combination of strategies for a src-dst node pair in a 1D list.
@@ -15,7 +15,7 @@ class CostGraph:
Argument:
leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (defaults to True)
- '''
+ """
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
@@ -39,10 +39,10 @@ def _remove_invalid_node(self, node, attr_name):
target_node_list.remove(element)
def _build_cost_graph(self):
- '''
+ """
This method generates edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children' attributes will be
set on each node.
- '''
+ """
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
@@ -84,8 +84,8 @@ def _check_tensor_in_node(data):
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
- setattr(dst_node, 'parents', parent_nodes)
- setattr(dst_node, 'children', children_nodes)
+ setattr(dst_node, "parents", parent_nodes)
+ setattr(dst_node, "children", children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -99,7 +99,7 @@ def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
- '''
+ """
To merge dst_node into src_node, we need to do it in the following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
@@ -119,7 +119,7 @@ def merge_node(self, src_node, dst_node):
Argument:
src_node(Node): The node that dst_node will be merged into.
dst_node(Node): The node to be merged into src_node.
- '''
+ """
# build merge_map
merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector):
@@ -196,7 +196,7 @@ def simplify_graph(self):
if not self.simplify:
return
self.merge_pair.reverse()
- for (src_node, dst_node) in self.merge_pair:
+ for src_node, dst_node in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
index 171aa8b3399f..678965d663e4 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
@@ -7,7 +7,7 @@
from colossalai.fx.passes.utils import get_node_module
-__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
+__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
@dataclass
@@ -15,6 +15,7 @@ class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
+
name: str
node: Node
is_inplace: bool
@@ -55,6 +56,7 @@ class LiveStage:
"""
LiveStage is a data structure to record the live variables at the current node.
"""
+
name: str
node: Node
all_live_vars: LiveVariableVector
@@ -62,7 +64,6 @@ class LiveStage:
class GraphAnalyser:
-
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@@ -105,18 +106,18 @@ def liveness_analysis(self) -> List[LiveStage]:
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
- if node.op == 'call_function':
+ if node.op == "call_function":
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
- if node.kwargs.get('inplace', False):
+ if node.kwargs.get("inplace", False):
is_inplace = True
- elif node.op == 'call_module':
+ elif node.op == "call_module":
# check if this is an inplace op such as torch.nn.ReLU(inplace=True)
module = get_node_module(node)
- if getattr(module, 'inplace', False):
+ if getattr(module, "inplace", False):
is_inplace = True
# add the output var
- meta = getattr(node, '_meta_data', None)
+ getattr(node, "_meta_data", None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
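The two in-place checks above rely on PyTorch's convention of exposing `inplace` as a module attribute or as a functional kwarg; a small standalone sketch:

    import torch.nn as nn

    # call_module case: the module carries an `inplace` attribute
    assert getattr(nn.ReLU(inplace=True), "inplace", False)
    assert not getattr(nn.ReLU(), "inplace", False)

    # call_function case: the kwargs carry `inplace`, as in F.relu(x, inplace=True)
    kwargs = {"inplace": True}
    assert kwargs.get("inplace", False)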
@@ -138,10 +139,12 @@ def liveness_analysis(self) -> List[LiveStage]:
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
- stage = LiveStage(name=node.name,
- node=node,
- all_live_vars=all_live_variables.copy(),
- unique_live_vars=unique_live_vars.copy())
+ stage = LiveStage(
+ name=node.name,
+ node=node,
+ all_live_vars=all_live_variables.copy(),
+ unique_live_vars=unique_live_vars.copy(),
+ )
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py
index 564c5f09220c..088d1acb5177 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/solver.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py
@@ -21,24 +21,25 @@
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
- warnings.warn(f'please install the pulp')
+ warnings.warn("please install the pulp package")
-__all___ = ['Solver']
+__all__ = ["Solver"]
class Solver:
-
- def __init__(self,
- graph: Graph,
- strategies_constructor: StrategiesConstructor,
- cost_graph: CostGraph,
- graph_analyser: GraphAnalyser = None,
- memory_budget: float = -1.0,
- solution_numbers: int = 1,
- forward_only: bool = False,
- memory_increasing_coefficient: float = 1.3,
- verbose=False):
- '''
+ def __init__(
+ self,
+ graph: Graph,
+ strategies_constructor: StrategiesConstructor,
+ cost_graph: CostGraph,
+ graph_analyser: GraphAnalyser = None,
+ memory_budget: float = -1.0,
+ solution_numbers: int = 1,
+ forward_only: bool = False,
+ memory_increasing_coefficient: float = 1.3,
+ verbose=False,
+ ):
+ """
The Solver class integrates information provided by the components and uses an ILP solver to find a possibly optimal strategy combination for the target computing graph.
Argument:
graph: The computing graph to be optimized.
@@ -48,7 +49,7 @@ def __init__(self,
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate a new memory budget.
- '''
+ """
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
@@ -75,11 +76,11 @@ def __init__(self,
self.verbose = verbose
def _recover_merged_node_strategy(self):
- '''
+ """
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOps, were merged into the previous node.
Therefore, the indices of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
nodes.
- '''
+ """
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
@@ -98,9 +99,9 @@ def _generate_node_index_dict(self) -> Dict[Node, int]:
return node_index_dict
def _prepare_data_for_solver(self):
- '''
+ """
Extract information from the components for the solver.
- '''
+ """
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
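For readers unfamiliar with pulp, here is a minimal, self-contained sketch of the kind of 0-1 program the solver assembles; the costs and variable names are toy stand-ins, not the real model:

    import pulp
    from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

    costs = [4.0, 2.5, 3.0]  # toy per-strategy costs for a single node
    s = [LpVariable(f"s_{i}", cat="Binary") for i in range(len(costs))]
    prob = LpProblem("strategy_selection", LpMinimize)
    prob += lpDot(s, costs)  # objective: total cost of the chosen strategy
    prob += lpSum(s) == 1    # pick exactly one strategy per node
    prob.solve(pulp.PULP_CBC_CMD(msg=False))
    assert [round(v.value()) for v in s] == [0, 1, 0]  # cheapest strategy wins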
@@ -190,23 +191,40 @@ def _prepare_data_for_solver(self):
# omit initial value for nodes
s_init_np = None
- return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
-
- def _call_solver_serialized_args(self,
- node_nums,
- memory_budget,
- strategies_len,
- following_nodes,
- edge_pairs,
- alias_set,
- liveness_set,
- compute_costs,
- communication_costs,
- memory_costs,
- resharding_costs,
- alias_convert_costs,
- s_init_np=None,
- verbose=True):
+ return (
+ node_nums,
+ memory_budget,
+ strategies_len,
+ following_nodes,
+ edge_pairs,
+ alias_set,
+ liveness_set,
+ compute_costs,
+ communication_costs,
+ memory_costs,
+ resharding_costs,
+ alias_convert_costs,
+ s_init_np,
+ self.verbose,
+ )
+
+ def _call_solver_serialized_args(
+ self,
+ node_nums,
+ memory_budget,
+ strategies_len,
+ following_nodes,
+ edge_pairs,
+ alias_set,
+ liveness_set,
+ compute_costs,
+ communication_costs,
+ memory_costs,
+ resharding_costs,
+ alias_convert_costs,
+ s_init_np=None,
+ verbose=True,
+ ):
"""
Call the solver with serialized arguments.
"""
@@ -235,18 +253,18 @@ def get_non_zero_index(binary_vector):
s_follow = following_nodes
s_alias = alias_set
- E = edge_pairs.reshape((-1, 2)) # noqa
+ E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
- for (i, j) in E:
+ for i, j in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
- r.append(resharding_costs[pt:pt + prod_length])
+ r.append(resharding_costs[pt : pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
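A small numeric sketch of the slicing above with toy sizes: the flat resharding_costs vector is consumed block by block, each block of strategies_len[i] * strategies_len[j] entries belonging to one edge (i, j).

    import numpy as np

    strategies_len = [2, 3, 2]
    E = np.array([[0, 1], [1, 2]])                 # two edges
    flat = np.arange(2 * 3 + 3 * 2, dtype=float)   # toy flat cost vector
    r, pt = [], 0
    for i, j in E:
        block = strategies_len[i] * strategies_len[j]
        r.append(flat[pt : pt + block])
        pt += block
    assert pt == len(flat) and len(r[0]) == 6 and len(r[1]) == 6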
@@ -268,7 +286,6 @@ def get_non_zero_index(binary_vector):
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
- v = []
pt = 0
c = []
@@ -277,9 +294,9 @@ def get_non_zero_index(binary_vector):
pt = 0
for i in range(node_nums):
length = strategies_len[i]
- c.append(compute_costs[pt:pt + length])
- d.append(communication_costs[pt:pt + length])
- m.append(memory_costs[pt:pt + length])
+ c.append(compute_costs[pt : pt + length])
+ d.append(communication_costs[pt : pt + length])
+ m.append(memory_costs[pt : pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
@@ -319,7 +336,7 @@ def get_non_zero_index(binary_vector):
e = []
num_edges = 0
map_edge_to_idx = {}
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
@@ -340,7 +357,7 @@ def get_non_zero_index(binary_vector):
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
- for (idx, value, fix) in s_init:
+ for idx, value, fix in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
@@ -393,7 +410,7 @@ def get_non_zero_index(binary_vector):
# (d). specified by `cat="Binary"`
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
@@ -402,13 +419,13 @@ def get_non_zero_index(binary_vector):
# (f)
for row in range(len(s[i])):
- C = len(s[j]) # noqa
+ C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
- R = len(s[i]) # noqa
- C = len(s[j]) # noqa
+ R = len(s[i]) # noqa
+ C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
@@ -434,7 +451,8 @@ def get_non_zero_index(binary_vector):
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
- onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
+ onlyAvailable=True
+ ), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
@@ -444,13 +462,13 @@ def get_non_zero_index(binary_vector):
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
- print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
- f"Time: {time.time() - tic}")
+ print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
- raise RuntimeError("Cannot run the function under the given memory budget. "
- "Please increase the memory budget.")
+ raise RuntimeError(
+ "Cannot run the function under the given memory budget. " "Please increase the memory budget."
+ )
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
@@ -458,7 +476,7 @@ def get_non_zero_index(binary_vector):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 044a8ac847ea..aa87ee9bf3db 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -1,11 +1,5 @@
-import builtins
-import math
-import operator
-from copy import deepcopy
-from typing import Dict, List
-
import torch
-from torch.fx import Graph, Node
+from torch.fx import Graph
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
@@ -14,13 +8,12 @@
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
-from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions
-__all__ = ['StrategiesConstructor']
+__all__ = ["StrategiesConstructor"]
class StrategiesConstructor:
@@ -35,7 +28,7 @@ class StrategiesConstructor:
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
- assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+ assert graph.owning_module is not None, "The given graph is not associated with an owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
@@ -46,11 +39,11 @@ def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: Solver
self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector):
- '''
+ """
In the build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we remove the duplicated strategies based on the strategy name.
Note that this operation is in-place.
- '''
+ """
name_checklist = []
remove_list = []
for strategy in strategies_vector:
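A hedged sketch of that name-based, in-place de-duplication using plain stand-in objects (real strategies carry far more state):

    class Strategy:  # hypothetical stand-in for a sharding strategy
        def __init__(self, name):
            self.name = name

    vector = [Strategy("RR"), Strategy("S0R"), Strategy("RR")]
    name_checklist, remove_list = [], []
    for strategy in vector:
        if strategy.name not in name_checklist:
            name_checklist.append(strategy.name)
        else:
            remove_list.append(strategy)
    for strategy in remove_list:
        vector.remove(strategy)
    assert [s.name for s in vector] == ["RR", "S0R"]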
@@ -62,7 +55,6 @@ def remove_duplicated_strategy(self, strategies_vector):
strategies_vector.remove(strategy)
def generate_alias_set(self):
-
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
@@ -83,7 +75,7 @@ def build_strategies_and_cost(self):
"""
def _check_no_strategy_for_node(node):
- if node.op in ('placeholder', 'get_attr', 'output'):
+ if node.op in ("placeholder", "get_attr", "output"):
return False
def _check_no_strategy_for_data(data):
@@ -102,83 +94,93 @@ def _check_no_strategy_for_data(data):
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
- pass
# placeholder node
- elif node.op == 'placeholder':
+ elif node.op == "placeholder":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
- placeholder_option = 'distributed'
+ placeholder_option = "distributed"
else:
- assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
- placeholder_option = 'replicated'
- placeholder_handler = PlaceholderHandler(node,
- self.device_mesh,
- strategies_vector,
- placeholder_option=placeholder_option)
+ assert (
+ self.solver_options.dataloader_option == DataloaderOption.REPLICATED
+ ), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
+ placeholder_option = "replicated"
+ placeholder_handler = PlaceholderHandler(
+ node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option
+ )
placeholder_handler.register_strategy()
# get_attr node
- elif node.op == 'get_attr':
- getattr_handler = GetattrHandler(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ elif node.op == "get_attr":
+ getattr_handler = GetattrHandler(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
getattr_handler.register_strategy()
# call_module node
- elif node.op == 'call_module':
+ elif node.op == "call_module":
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
- handler = operator_registry.get(submod_type)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(submod_type)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# call_function node
- elif node.op == 'call_function':
+ elif node.op == "call_function":
target = node.target
- handler = operator_registry.get(target)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(target)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# call_method node
- elif node.op == 'call_method':
+ elif node.op == "call_method":
method = getattr(node.args[0]._meta_data.__class__, node.target)
- handler = operator_registry.get(method)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(method)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# output node
- elif node.op == 'output':
+ elif node.op == "output":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
- output_option = 'distributed'
+ output_option = "distributed"
else:
- assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
- output_option = 'replicated'
+ assert (
+ self.solver_options.dataloader_option == DataloaderOption.REPLICATED
+ ), f"dataloader_option {self.solver_options.dataloader_option} is not supported"
+ output_option = "replicated"
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)
- setattr(node, 'strategies_vector', strategies_vector)
+ setattr(node, "strategies_vector", strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
index b7fe5430bf13..d61cfd2add15 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
@@ -17,9 +17,21 @@
)
__all__ = [
- 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
- 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
- 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
- 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
- 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
+ "BroadcastType",
+ "get_broadcast_shape",
+ "is_broadcastable",
+ "recover_sharding_spec_for_broadcast_shape",
+ "generate_resharding_costs",
+ "generate_sharding_spec",
+ "ignore_sharding_exception",
+ "check_sharding_spec_validity",
+ "transpose_partition_dim",
+ "update_partition_dim",
+ "enumerate_all_possible_1d_sharding",
+ "enumerate_all_possible_2d_sharding",
+ "generate_sharding_size",
+ "comm_actions_for_oprands",
+ "pytree_map",
+ "detect_reshape_mapping",
+ "check_keep_sharding_status",
+ "infer_output_dim_partition_dict",
]
diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
index 307348ea1eaf..99d5a0f2a942 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
@@ -14,8 +14,11 @@
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
- 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
- 'comm_actions_for_oprands'
+ "BroadcastType",
+ "is_broadcastable",
+ "get_broadcast_shape",
+ "recover_sharding_spec_for_broadcast_shape",
+ "comm_actions_for_oprands",
]
@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
- assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
+ assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
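For reference, torch ships the same broadcast-shape computation, which makes the semantics guarded by the assertion above easy to check:

    import torch

    assert list(torch.broadcast_shapes((4, 1, 3), (2, 3))) == [4, 2, 3]
    try:
        torch.broadcast_shapes((4, 2, 3), (5, 3))  # 2 vs 5: not broadcastable
    except RuntimeError as e:
        print(e)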
@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
- assert logical_num_dims >= physical_num_dims, \
- 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
+ assert (
+ logical_num_dims >= physical_num_dims
+ ), "The number of dimensions in the logical shape is smaller than that of the physical shape; this tensor is not a broadcast tensor!"
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
@@ -85,8 +89,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
return logical_dim_broadcast_info
-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
- physical_shape: torch.Size) -> ShardingSpec:
+def recover_sharding_spec_for_broadcast_shape(
+ logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size
+) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
@@ -124,15 +129,18 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim
- physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
- entire_shape=physical_shape,
- dim_partition_dict=physical_dim_partition)
+ physical_sharding_spec = ShardingSpec(
+ device_mesh=logical_sharding_spec.device_mesh,
+ entire_shape=physical_shape,
+ dim_partition_dict=physical_dim_partition,
+ )
return physical_sharding_spec, removed_dims
-def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
- sharding_spec: ShardingSpec) -> CommAction:
+def comm_actions_for_oprands(
+ node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec
+) -> CommAction:
"""
This method is used to generate communication actions for operands which lose information
when converting the logical shape to the physical shape.
@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0]
- comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
- sharding_spec=sharding_spec,
- logical_process_axis=removed_dims)
+ comm_spec = CommSpec(
+ comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+ sharding_spec=sharding_spec,
+ logical_process_axis=removed_dims,
+ )
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
- assert arg_index >= 0, f'op_data should be an argument of node.'
+ assert arg_index >= 0, "op_data should be an argument of the node."
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,
diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py
index 347c10aa102d..aaca923a5eee 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/factory.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py
@@ -14,11 +14,12 @@
from ..constants import INFINITY_COST
-__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
+__all__ = ["generate_sharding_spec", "generate_resharding_costs"]
-def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
- dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+def generate_sharding_spec(
+ input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]]
+) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
@@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
"""
if isinstance(input_, Node):
- assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
+ assert hasattr(input_, "_meta_data"), "The given node has no attribute _meta_data"
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
@@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
shape = input_.shape
else:
raise TypeError(
- f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
+ f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected."
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
- assert shape[
- dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
+ assert (
+ shape[dim_index] % sharding_size == 0
+ ), f"we cannot shard dimension {dim_index} of the tensor into {sharding_size} partitions."
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
-def generate_resharding_costs(nodes: List[Node],
- sharding_specs: List[ShardingSpec],
- count_backward: Optional[bool] = True,
- dtype: Optional[torch.dtype] = None,
- index=None):
- '''
+def generate_resharding_costs(
+ nodes: List[Node],
+ sharding_specs: List[ShardingSpec],
+ count_backward: Optional[bool] = True,
+ dtype: Optional[torch.dtype] = None,
+ index=None,
+):
+ """
Compute the resharding costs with this specific strategy.
Argument:
@@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node],
sharding_specs(List[ShardingSpec]): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
- '''
+ """
# The resharding_cost of the weight is counted due to weight-sharing cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
@@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node],
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
- assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
+ assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected."
input_sharding_spec = input_sharding_spec[index]
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+ assert isinstance(input_sharding_spec, ShardingSpec), f"The input node should NOT be a tuple of tensor."
try:
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
- input_sharding_spec, input_spec)
+ input_sharding_spec, input_spec
+ )
# we need to multiply by the element size of the dtype to get the correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
- warnings.warn(f'{e}')
+ warnings.warn(f"{e}")
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
- '''
+ """
Find the largest repeat blocks in the graph whose length is larger than the threshold.
Args:
    node_list (List[torch.fx.Node]): the node list to be analyzed.
    root_module: the root module that owns the nodes.
    common_length_threshold (int): the threshold of the repeat block length.
- '''
+ """
# graph = gm.graph
def _process_args(args):
new_args = []
for arg in args:
- if hasattr(arg, '_meta_data'):
+ if hasattr(arg, "_meta_data"):
meta_data = arg._meta_data
else:
meta_data = arg
@@ -145,7 +150,7 @@ def _check_node_equal(node1, node2):
return False
for index, node in enumerate(node_list):
- if node.op == 'call_module':
+ if node.op == "call_module":
target = node.target
submod = root_module.get_submodule(target)
submod_type = type(submod)
@@ -155,12 +160,12 @@ def _check_node_equal(node1, node2):
new_args = _process_args(node.args)
- if node.op != 'get_attr':
+ if node.op != "get_attr":
hash_key = (node.op, target, *new_args)
else:
hash_key = (node.op,)
- setattr(node, 'hash_key', hash_key)
+ setattr(node, "hash_key", hash_key)
hash_value_to_node_dict = {}
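A hedged sketch of the bucketing that follows (tuples standing in for real nodes): nodes whose (op, target, processed-args) keys hash alike land in the same bucket, which is how repeated blocks are located.

    node_a = ("call_module", "blocks.0.linear", ("f32[4, 16]",))  # hypothetical keys
    node_b = ("call_module", "blocks.1.linear", ("f32[4, 16]",))
    bucket = {}
    for key in (node_a, node_b, node_a):
        bucket.setdefault(hash(key), []).append(key)
    assert len(bucket[hash(node_a)]) == 2  # the repeated node lands in the same bucket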
@@ -179,7 +184,7 @@ def _check_node_equal(node1, node2):
# the comparison will be triggered if a common node appears
if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
- check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
+ check_block_list = [node_list[start : start + max_common_length] for start in start_index_list]
common_label = True
if not _all_equal(check_block_list, _check_node_list_equal):
@@ -201,6 +206,6 @@ def _check_node_equal(node1, node2):
# recover common subgraph from the index
common_blocks = []
for start in common_blocks_index:
- common_blocks.append(node_list[start:start + max_common_length])
+ common_blocks.append(node_list[start : start + max_common_length])
return common_blocks
diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py
index 475e95fc4326..42ec2a8ee428 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/misc.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py
@@ -1,12 +1,12 @@
import functools
-from typing import Any, Callable, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Tuple, Type, Union
import torch
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
-__all__ = ['ignore_sharding_exception', 'pytree_map']
+__all__ = ["ignore_sharding_exception", "pytree_map"]
def ignore_sharding_exception(func):
@@ -48,29 +48,32 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.shape[0]
num_devices_in_row = sharding_spec.device_mesh.shape[1]
- assert sharding_len == tensor_num_dim, \
- f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+ assert (
+ sharding_len == tensor_num_dim
+ ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})."
# make sure the sharding is valid for each dim
for i in range(tensor_num_dim):
dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i]
- if str(dim_spec).startswith('S'):
- devices_str = str(dim_spec).lstrip('S')
+ if str(dim_spec).startswith("S"):
+ devices_str = str(dim_spec).lstrip("S")
num_devices = 1
- if '0' in devices_str:
+ if "0" in devices_str:
num_devices *= num_devices_in_col
- if '1' in devices_str:
+ if "1" in devices_str:
num_devices *= num_devices_in_row
- assert dim_size >= num_devices and dim_size % num_devices == 0, \
- f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
+ assert (
+ dim_size >= num_devices and dim_size % num_devices == 0
+ ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices."
# make sure the entire shape matches the physical tensor shape
- assert sharding_spec.entire_shape == tensor.shape, \
- f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
+ assert (
+ sharding_spec.entire_shape == tensor.shape
+ ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}"
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
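A tiny numeric instance of the validity rule enforced above: a dim spec such as S01 shards one tensor dimension over both mesh axes, so the dimension size must be divisible by the product of the two axis sizes.

    dim_size = 8
    num_devices_in_col, num_devices_in_row = 2, 2
    num_devices = num_devices_in_col * num_devices_in_row  # "S01" uses both axes
    assert dim_size >= num_devices and dim_size % num_devices == 0  # 8 % 4 == 0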
diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
index d0ebbd7e8b1b..329312ef797f 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/reshape.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
@@ -8,6 +8,7 @@ class PreviousStatus(Enum):
"""
This class shows the status of the previous comparison.
"""
+
RESET = 0
# ORIGIN means the dimension size of the original tensor is larger in the previous comparison.
ORIGIN = 1
@@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
return reshape_mapping_dict
-def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
- reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
+def check_keep_sharding_status(
+ input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
+) -> bool:
"""
This method is used to check whether the reshape operation can be implemented without converting
the input to a fully replicated status.
@@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
return True
-def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
- reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]:
+def infer_output_dim_partition_dict(
+ input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
+) -> Dict[Tuple[int], Tuple[int]]:
"""
This method is used to infer the output dim partition dict for a reshape operation,
given the input dim partition dict and reshape mapping dict.
"""
- assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
- 'we only infer output dim partition dict for the reshape operation could keep sharding spec.'
+ assert check_keep_sharding_status(
+ input_dim_partition_dict, reshape_mapping_dict
+ ), "we only infer the output dim partition dict for reshape operations that can keep the sharding spec."
sharded_dims = list(input_dim_partition_dict.keys())
output_dim_partition_dict = {}
for input_dims, output_dims in reshape_mapping_dict.items():
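A hedged sketch of the transfer this loop performs, with a toy mapping rather than real detect_reshape_mapping output: reshaping (4, 6) to (4, 2, 3) keeps dim 0 one-to-one, so a shard on dim 0 survives and lands on the mapped output dim.

    reshape_mapping_dict = {(0,): (0,), (1,): (1, 2)}  # toy (4, 6) -> (4, 2, 3)
    input_dim_partition_dict = {0: [0]}                # tensor dim 0 on mesh axis 0
    output_dim_partition_dict = {}
    for input_dims, output_dims in reshape_mapping_dict.items():
        for dim in input_dims:
            if dim in input_dim_partition_dict:
                output_dim_partition_dict[output_dims[0]] = input_dim_partition_dict[dim]
    assert output_dim_partition_dict == {0: [0]}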
diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
index e2ce59e0b577..b5386d599be4 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
@@ -8,8 +8,11 @@
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
- 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
- 'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
+ "transpose_partition_dim",
+ "update_partition_dim",
+ "enumerate_all_possible_1d_sharding",
+ "enumerate_all_possible_2d_sharding",
+ "generate_sharding_size",
]
@@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
- assert len(sharding_spec.entire_shape) >= 2, \
- 'The entire_shape of the sharding spec must have at least 2 dimensions'
+ assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions"
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
@@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
return sharding_spec
-def update_partition_dim(sharding_spec: ShardingSpec,
- dim_mapping: Dict[int, int],
- physical_shape: torch.Size,
- inplace: bool = False):
+def update_partition_dim(
+ sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False
+):
"""
This method is used to update the partition dim dict from the logical one to the physical one.
@@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec,
new_dim_partition_dict[tensor_dim] = mesh_dims
# update sharding spec
- current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
- entire_shape=physical_shape,
- dim_partition_dict=new_dim_partition_dict)
+ current_sharding_spec.__init__(
+ device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict
+ )
return current_sharding_spec
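At the data-structure level, update_partition_dim amounts to re-keying the dim partition dict from logical to physical tensor dims; a minimal illustration with plain dicts (not the real ShardingSpec API):

    dim_partition_dict = {0: [0], 2: [1]}  # logical dims 0 and 2 are sharded
    dim_mapping = {0: 1, 2: 3}             # logical dim -> physical dim (toy)
    new_dim_partition_dict = {
        dim_mapping[tensor_dim]: mesh_dims
        for tensor_dim, mesh_dims in dim_partition_dict.items()
    }
    assert new_dim_partition_dict == {1: [0], 3: [1]}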
diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index cc98c1570b4a..9571fa2c17f0 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -9,7 +9,18 @@
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
- from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
+ from torch.fx.graph import (
+ CodeGen,
+ PythonCode,
+ _custom_builtins,
+ _CustomBuiltin,
+ _format_target,
+ _is_from_torch,
+ _Namespace,
+ _origin_type_map,
+ inplace_methods,
+ magic_methods,
+ )
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
- tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
- input_node.name)
- tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+ tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
+ shape_str,
+ input_node.name,
+ input_node.name,
+ )
+ tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
- context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
- input_node.name, input_node.name)
+ context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
+ chunk_output[i].name,
+ shape_str,
+ input_node.name,
+ input_node.name,
+ )
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_output_dim[0]]
@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
return context
-def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
- chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
+def _gen_loop_end(
+ chunk_inputs: List[Node],
+ chunk_non_compute_inputs: List[Node],
+ node_list: List[Node],
+ chunk_outputs_idx: int,
+ chunk_outputs_non_tensor: List[Node],
+ search_chunk: SearchChunk,
+) -> str:
"""
Generate chunk loop end
@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
- if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
- or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
+ if (
+ source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+ or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
+ ):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
@@ -203,11 +229,12 @@ def _add_node_slice(
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
- chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
- get_node_shape(chunk_node))
+ chunk_slice = _gen_chunk_slice_dim(
+ chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
+ )
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
- for i in range(len(chunk_node.meta['tensor_meta'])):
+ for i in range(len(chunk_node.meta["tensor_meta"])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
@@ -216,13 +243,15 @@ def _add_node_slice(
return body
-def emit_code_with_chunk(body: List[str],
- nodes: Iterable[Node],
- emit_node_func: Callable,
- delete_unused_value_func: Callable,
- search_chunk: SearchChunk,
- chunk_infos: List,
- eval_mem: bool = False):
+def emit_code_with_chunk(
+ body: List[str],
+ nodes: Iterable[Node],
+ emit_node_func: Callable,
+ delete_unused_value_func: Callable,
+ search_chunk: SearchChunk,
+ chunk_infos: List,
+ eval_mem: bool = False,
+):
"""
Emit code with chunk according to chunk_infos.
@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs
- chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
- chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
- chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
+ chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
+ chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
+ chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
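Stripped of codegen plumbing, the emitted chunk region follows a simple pattern: pre-allocate the output with torch.empty, then loop over slices of the chunked dimension. A hand-written, hedged equivalent with a toy op standing in for the real region body:

    import torch

    x = torch.randn(8, 16)
    chunk_size, dim = 2, 0
    out = torch.empty(8, 16, dtype=x.dtype, device=x.device)  # pre-allocated output
    for chunk_idx in range(0, x.shape[dim], chunk_size):
        sl = slice(chunk_idx, chunk_idx + chunk_size)
        out[sl] = torch.relu(x[sl])  # stand-in for the chunked region body
    assert torch.equal(out, torch.relu(x))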
@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
- ))
+ )
+ )
if within_chunk_region:
emit_node_func(node, body)
@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
- % (node.name))
+ % (node.name)
+ )
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
- % (node.name))
+ % (node.name)
+ )
# generate chunk region end
if node_idx in chunk_ends:
body.append(
- _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
- chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
+ _gen_loop_end(
+ chunk_inputs[region_idx],
+ chunk_inputs_non_chunk[region_idx],
+ node_list,
+ chunk_ends[region_idx],
+ chunk_outputs_non_tensor[region_idx],
+ search_chunk,
+ )
+ )
within_chunk_region = False
node_idx += 1
@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
-
- def __init__(self,
- meta_graph,
- max_memory: int = None,
- print_mem: bool = False,
- print_progress: bool = False,
- eval_mem: bool = False) -> None:
+ def __init__(
+ self,
+ meta_graph,
+ max_memory: int = None,
+ print_mem: bool = False,
+ print_progress: bool = False,
+ eval_mem: bool = False,
+ ) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
@@ -349,7 +389,7 @@ def add_global(name_hint: str, obj: Any):
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -402,7 +442,6 @@ def type_repr(o: Any):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
@@ -457,10 +496,10 @@ def delete_unused_values(user: Node, body, to_keep=[]):
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
@@ -470,42 +509,56 @@ def emit_node(node: Node, body):
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
- f"({_format_args(node.args[1:], node.kwargs)})")
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f"{repr(node)}{maybe_type_annotation} = "
- f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
- body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
- f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
- and node.args[1].isidentifier() and len(node.args) == 2):
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f"{repr(node)}{maybe_type_annotation} = "
- f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
@@ -523,8 +576,9 @@ def emit_node(node: Node, body):
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
- emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
- self.eval_mem)
+ emit_code_with_chunk(
+ body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
+ )
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index 77bc2ef17bc3..a85ad429e261 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -1,11 +1,8 @@
-import copy
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Dict, List
import torch
from torch.fx.node import Node
-from colossalai.fx.profiler import activation_size, parameter_size
-
from .utils import NodeMgr, get_node_shape, is_non_memory_node
@@ -62,12 +59,9 @@ def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict:
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
- def _remove_deactive_node(self,
- user_idx: int,
- user: Node,
- active_nodes: List,
- delete_node_dict: List,
- kept_nodes: List = None) -> None:
+ def _remove_deactive_node(
+ self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
+ ) -> None:
"""
remove deactivated nodes from the active node list
"""
@@ -169,7 +163,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
- chunk_ratio = 1 # use it to estimate chunk mem
+ chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_all = []
if use_chunk:
@@ -184,7 +178,6 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()):
-
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
@@ -193,8 +186,9 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
# determine chunk ratio for current node
if chunk_within:
- chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
- chunk_sizes[chunk_region_idx])
+ chunk_ratio = self._get_chunk_ratio(
+ node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
+ )
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
@@ -222,7 +216,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
- self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
+ self._remove_deactive_node(idx, node, active_nodes, delete_node_dict)  # don't provide kept nodes now
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
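The chunk_ratio bookkeeping above reduces to a proportional scaling of activation sizes inside a chunked region; a back-of-envelope sketch with toy numbers:

    full_dim_size, chunk_size = 1024, 64
    chunk_ratio = chunk_size / full_dim_size     # 0.0625
    full_activation_mb = 512.0                   # hypothetical activation size
    chunked_activation_mb = full_activation_mb * chunk_ratio
    assert chunked_activation_mb == 32.0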
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 59645c80e808..1c599049d9eb 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,7 +8,7 @@
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
-from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -121,8 +121,10 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re
# check if peak node already in chunk info
if chunk_regions is not None:
for i in chunk_regions:
- if i["region"][0] < peak_region[0] <= i["region"][1] or \
- i["region"][0] < peak_region[1] <= i["region"][1]:
+ if (
+ i["region"][0] < peak_region[0] <= i["region"][1]
+ or i["region"][0] < peak_region[1] <= i["region"][1]
+ ):
return None
active_node_num = [len(i) for i in active_node]
@@ -146,9 +148,9 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
- elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
+ elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
chunk_region_start = region[1] + 1
- elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
+ elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
@@ -171,7 +173,7 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
- if len(start_traces) > 1: # TODO need to be removed
+ if len(start_traces) > 1:  # TODO: needs to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
@@ -180,8 +182,9 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
- if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
- end_idx):
+ if not self.trace_flow.check_region_start_end(
+ start_node, start_dim, start_idx, end_node, end_dim, end_idx
+ ):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
@@ -203,7 +206,7 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
- input_trace = [] # trace of a node's input nodes
+ input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {}
for arg in n.args:
@@ -215,7 +218,8 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
- self.node_mgr.get_node_by_idx(end_idx)):
+ self.node_mgr.get_node_by_idx(end_idx)
+ ):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
@@ -279,15 +283,18 @@ def search_region(self) -> Dict:
chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
- self.node_mgr.get_node_list(), chunk_infos)
+ self.node_mgr.get_node_list(), chunk_infos
+ )
if self.print_progress:
- get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
- (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
+ get_logger().info(
+                    "AutoChunk found chunk region %d = (%d, %d)"
+ % (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
+ )
if self.print_mem:
self.print_mem = False
- self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
- chunk_infos,
- print_mem=True)
+ self.estimate_memory.estimate_chunk_inference_mem(
+ self.node_mgr.get_node_list(), chunk_infos, print_mem=True
+ )
return chunk_infos
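
Note on the condition reformatted in `_search_max_chunk_region`: it rejects a peak region whose endpoints fall inside any already-recorded chunk region. A minimal standalone sketch of that predicate (the `Region` alias and function name are illustrative, not part of the patch):

from typing import List, Tuple

Region = Tuple[int, int]  # (start_idx, end_idx), inclusive

def peak_inside_existing(peak: Region, chunk_regions: List[dict]) -> bool:
    # True if either endpoint of `peak` lies inside a recorded region,
    # mirroring the rejection test in _search_max_chunk_region above.
    for info in chunk_regions:
        lo, hi = info["region"]
        if lo < peak[0] <= hi or lo < peak[1] <= hi:
            return True
    return False
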
diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py
index 94a29bfd5691..8a60ba681f70 100644
--- a/colossalai/autochunk/select_chunk.py
+++ b/colossalai/autochunk/select_chunk.py
@@ -5,7 +5,6 @@
class SelectChunk(object):
-
def __init__(
self,
trace_indice: TraceIndice,
@@ -20,7 +19,7 @@ def __init__(
self.node_mgr = node_mgr
if max_memory is not None:
self.stratge = "fit_memory"
- self.max_memory = max_memory # MB
+ self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
@@ -57,16 +56,18 @@ def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, m
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
- cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
+ cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
- regions_dict.append({
- "chunk_info": region,
- "chunk_max_mem": cur_chunk_region_max_peak,
- "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
- "reorder_chunk_info": cur_region,
- "reorder_node_list": cur_node_list,
- })
+ regions_dict.append(
+ {
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ }
+ )
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -90,13 +91,15 @@ def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
- cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
- cur_chunk_infos)[0]
- cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+ chunk_region_dict["reorder_node_list"], cur_chunk_infos
+ )[0]
+ cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
- chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
- chunk_infos)
+ chunk_info["chunk_size"] = self._chunk_size_binary_search(
+ chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
+ )
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -109,9 +112,10 @@ def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos)
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
- cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
- cur_chunk_infos)[0]
- cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+ chunk_region_dict["reorder_node_list"], cur_chunk_infos
+ )[0]
+ cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
@@ -139,8 +143,10 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
return None
# get max possible chunk region
- max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
- max([i["region"][1] for i in possible_chunk_regions]))
+ max_possible_chunk_region = (
+ min([i["region"][0] for i in possible_chunk_regions]),
+ max([i["region"][1] for i in possible_chunk_regions]),
+ )
# get mem for chunk region
regions_dict_list = []
@@ -149,15 +155,17 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
- cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
+ cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
- regions_dict_list.append({
- "chunk_info": region,
- "chunk_max_mem": cur_chunk_region_max_peak,
- "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
- "reorder_chunk_info": cur_region,
- "reorder_node_list": cur_node_list,
- })
+ regions_dict_list.append(
+ {
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ }
+ )
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
@@ -175,7 +183,9 @@ def _is_legal_region(self, cur_chunk_info, chunk_infos):
return False
for i in chunk_infos:
region = i["region"]
- if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
- (chunk_region_start < region[0] and chunk_region_end < region[0])):
+ if not (
+ (chunk_region_start > region[1] and chunk_region_end > region[1])
+ or (chunk_region_start < region[0] and chunk_region_end < region[0])
+ ):
return False
return True
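
The `_chunk_size_binary_search` calls reformatted above implement a standard search for the largest chunk size whose estimated peak memory stays under `self.max_memory`. A self-contained sketch of that idea, assuming peak memory grows monotonically with chunk size (`peak_mem` is a hypothetical stand-in for `estimate_chunk_inference_mem`):

def largest_feasible_chunk_size(left: int, right: int, max_memory: float, peak_mem) -> int:
    # `peak_mem(size)` estimates the peak memory for a given chunk size;
    # memory is assumed monotone in size, so the feasible set is a prefix.
    while left < right:
        mid = (left + right + 1) // 2  # bias upward so the interval always shrinks
        if peak_mem(mid) >= max_memory:
            right = mid - 1
        else:
            left = mid
    return left
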
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index a1080fda1541..8b36c99bbadd 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -16,7 +16,6 @@
class TraceFlow(object):
-
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
@@ -151,7 +150,7 @@ def _assign_single_node_flow(
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
- cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
+ cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@@ -266,7 +265,7 @@ def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int,
maybe_prepose_nodes.sort(
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
- ) # from last node to first node
+ ) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
@@ -328,7 +327,8 @@ def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
- self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
+ self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
+ )
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
@@ -371,8 +371,9 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim):
return chunk_info
- def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
- chunk_info: Dict):
+ def _get_other_output_info(
+ self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
+ ):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
@@ -384,8 +385,8 @@ def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim:
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
- if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
- chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+ if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
+ chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
@@ -421,7 +422,8 @@ def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output:
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
- set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+ set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
+ )
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
@@ -443,8 +445,11 @@ def _reassign_reshape_size(self, chunk_info):
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
- if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
- reshape_args[0].meta['fwd_out']) > 1:
+ if (
+ len(reshape_args) == 1
+ and get_node_shape(reshape_args[0]) is None
+ and len(reshape_args[0].meta["fwd_out"]) > 1
+ ):
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
@@ -462,16 +467,17 @@ def _reassign_reshape_size(self, chunk_info):
chunk_info["reshape_size"] = reshape_size
return chunk_info
- def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
- end_idx: int) -> bool:
+ def check_region_start_end(
+ self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
+ ) -> bool:
"""
        check if region start and end are legal
"""
# dim cannot be None
- if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+ if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
return False
# dim size cannot be 1
- if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+ if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
return False
# must have users
if len(end_node.users) == 0:
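
For context, `check_region_start_end` reduces to three cheap guards on the candidate endpoints. A sketch of the equivalent predicate, assuming `get_node_shape` returns a tensor shape or `None` as in `colossalai/autochunk/utils.py`:

def region_endpoints_legal(start_node, start_dim, end_node, end_dim, get_node_shape) -> bool:
    # Sketch of the guards in TraceFlow.check_region_start_end.
    start_shape, end_shape = get_node_shape(start_node), get_node_shape(end_node)
    if start_shape is None or end_shape is None:  # dim cannot be None
        return False
    if start_shape[start_dim] == 1 or end_shape[end_dim] == 1:  # dim size cannot be 1
        return False
    return len(end_node.users) > 0  # the end node must have users
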
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index fbe0741b8827..378c54acf782 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -1,5 +1,5 @@
import copy
-from typing import Dict, List, Tuple
+from typing import Dict, List
from torch.fx.node import Node
@@ -412,7 +412,7 @@ def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
node_idx (int)
"""
        # check interpolate input
- assert node.kwargs['size'] is None
+ assert node.kwargs["size"] is None
assert len(get_node_shape(node)) == 4
# assign index
@@ -826,7 +826,7 @@ def _clear_trace(self, node_idx: int) -> None:
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
- if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
+ if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
@@ -876,10 +876,24 @@ def trace_indice(self) -> None:
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
- elif any(n == node_name for n in [
- "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
- "sin", "cos"
- ]):
+ elif any(
+ n == node_name
+ for n in [
+ "mul",
+ "add",
+ "sigmoid",
+ "relu",
+ "sub",
+ "truediv",
+ "pow",
+ "dropout",
+ "where",
+ "tanh",
+ "exp",
+ "sin",
+ "cos",
+ ]
+ ):
self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
@@ -920,7 +934,7 @@ def trace_indice(self) -> None:
else:
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
- self._assign_all_indice(node, idx) # get param
+ self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else:
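
The expanded `elif any(n == node_name for n in [...])` above is just a membership test over elementwise op names; the same dispatch could be written against a `frozenset` for O(1) lookup. A sketch (op names copied from the patch; the helper name is illustrative):

_ELEMENTWISE_OPS = frozenset({
    "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow",
    "dropout", "where", "tanh", "exp", "sin", "cos",
})

def is_elementwise(node_name: str) -> bool:
    # Equivalent to the any(...) membership test in trace_indice().
    return node_name in _ELEMENTWISE_OPS
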
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
index 064baa047155..f6f803a5ce0a 100644
--- a/colossalai/autochunk/utils.py
+++ b/colossalai/autochunk/utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, List, Union
from torch.fx.node import Node
@@ -10,7 +10,6 @@
class NodeMgr(object):
-
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
- if (input_node not in nodes and input_node not in input_nodes
- and not is_non_compute_node_except_placeholder(input_node)):
+ if (
+ input_node not in nodes
+ and input_node not in input_nodes
+ and not is_non_compute_node_except_placeholder(input_node)
+ ):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
- if (output_node not in nodes and node not in output_nodes
- and not is_non_compute_node_except_placeholder_output(output_node)):
+ if (
+ output_node not in nodes
+ and node not in output_nodes
+ and not is_non_compute_node_except_placeholder_output(output_node)
+ ):
output_nodes.append(node)
return input_nodes, output_nodes
@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
- elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
- node.meta['fwd_out'][0], int):
+ elif (
+ len(node.meta["fwd_out"]) > 0
+ and isinstance(node.meta["fwd_out"], list)
+ and isinstance(node.meta["fwd_out"][0], int)
+ ):
out.append(node)
return out
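
The reworked `elif` in `find_tensor_shape_node` keeps a node either when it carries a real tensor shape or when its `meta["fwd_out"]` is a non-empty list of ints, i.e. a shape recorded as data. A compact restatement under the same metadata assumptions:

def keeps_shape_info(node, get_node_shape) -> bool:
    # A node is kept when it has a real tensor shape, or when its forward
    # output metadata is a non-empty list of ints (a shape stored as data).
    if get_node_shape(node) is not None:
        return True
    fwd_out = node.meta["fwd_out"]
    return isinstance(fwd_out, list) and len(fwd_out) > 0 and isinstance(fwd_out[0], int)
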
diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py
index fc2c4a40068b..92990907bc2e 100644
--- a/colossalai/booster/accelerator.py
+++ b/colossalai/booster/accelerator.py
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
-__all__ = ['Accelerator']
+__all__ = ["Accelerator"]
_supported_devices = [
- 'cpu',
- 'cuda',
-
+ "cpu",
+ "cuda",
# To be supported
# 'xpu',
# 'npu',
@@ -25,21 +24,22 @@ class Accelerator:
def __init__(self, device: str):
self.device = device
- assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
+ assert (
+ self.device in _supported_devices
+ ), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
def bind(self):
"""
Set the default device for the current process.
"""
- if self.device == 'cpu':
+ if self.device == "cpu":
pass
- elif self.device == 'cuda':
+ elif self.device == "cuda":
# TODO(FrankLeeeee): use global environment to check if it is a dist job
# if is_distributed:
# local_rank = EnvTable().get_local_rank()
# torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
- torch.cuda.set_device(torch.device('cuda'))
- pass
+ torch.cuda.set_device(torch.device("cuda"))
else:
raise ValueError(f"Device {self.device} is not supported yet")
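
Usage of `Accelerator` is unchanged by the reformat; a minimal sketch, assuming a CUDA-capable process:

from colossalai.booster.accelerator import Accelerator

acc = Accelerator("cuda")  # raises AssertionError for devices outside _supported_devices
acc.bind()                 # sets the default CUDA device for this process
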
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index fb9dae7c9650..2aee72cbf2f1 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -16,7 +16,7 @@
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase
-__all__ = ['Booster']
+__all__ = ["Booster"]
class Booster:
@@ -60,28 +60,31 @@ class Booster:
plugin (Plugin): The plugin to run the training. Default: None.
"""
- def __init__(self,
- device: Optional[str] = None,
- mixed_precision: Optional[Union[MixedPrecision, str]] = None,
- plugin: Optional[Plugin] = None) -> None:
+ def __init__(
+ self,
+ device: Optional[str] = None,
+ mixed_precision: Optional[Union[MixedPrecision, str]] = None,
+ plugin: Optional[Plugin] = None,
+ ) -> None:
if plugin is not None:
assert isinstance(
- plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
+ plugin, Plugin
+ ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
if device is not None:
- warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+ warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
else:
- device = device or 'cuda'
+ device = device or "cuda"
self.accelerator = Accelerator(device)
# set precision
if self.plugin and self.plugin.control_precision():
if mixed_precision is not None:
- warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+ warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
@@ -95,7 +98,7 @@ def __init__(self,
self.mixed_precision = mixed_precision
else:
raise ValueError(
- f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
+ f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}."
)
if self.plugin is not None and self.plugin.control_checkpoint_io():
@@ -131,7 +134,8 @@ def boost(
# transform model for mixed precision
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
- model, optimizer, criterion, dataloader, lr_scheduler)
+ model, optimizer, criterion, dataloader, lr_scheduler
+ )
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
@@ -154,13 +158,15 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
# TODO(frank lee): implement this method with plugin
optimizer.backward(loss)
- def execute_pipeline(self,
- data_iter: Iterator,
- model: nn.Module,
- criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Optional[Optimizer] = None,
- return_loss: bool = True,
- return_outputs: bool = False) -> Dict[str, Any]:
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: nn.Module,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[Optimizer] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> Dict[str, Any]:
"""
Execute forward & backward when utilizing pipeline parallel.
Return loss or Huggingface style model outputs if needed.
@@ -185,8 +191,9 @@ def execute_pipeline(self,
ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
            ret_dict['outputs'] is the Huggingface style model outputs during forward if return_outputs is set to True, else None.
"""
- assert isinstance(self.plugin,
- PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
+ assert isinstance(
+ self.plugin, PipelinePluginBase
+ ), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
@@ -200,8 +207,10 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
Returns:
contextmanager: Context to disable gradient synchronization.
"""
- assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
- assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+ assert (
+ self.plugin is not None
+        ), "no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
+ assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
return self.plugin.no_sync(model, optimizer)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
@@ -217,14 +226,16 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
- def save_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- shard: bool = False,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False) -> None:
+ def save_model(
+ self,
+ model: Union[nn.Module, ModelWrapper],
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
"""Save model to checkpoint.
Args:
@@ -239,13 +250,15 @@ def save_model(self,
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
            use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
"""
- self.checkpoint_io.save_model(model,
- checkpoint=checkpoint,
- shard=shard,
- gather_dtensor=gather_dtensor,
- prefix=prefix,
- size_per_shard=size_per_shard,
- use_safetensors=use_safetensors)
+ self.checkpoint_io.save_model(
+ model,
+ checkpoint=checkpoint,
+ shard=shard,
+ gather_dtensor=gather_dtensor,
+ prefix=prefix,
+ size_per_shard=size_per_shard,
+ use_safetensors=use_safetensors,
+ )
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint.
@@ -260,13 +273,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
- def save_optimizer(self,
- optimizer: Optimizer,
- checkpoint: str,
- shard: bool = False,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024) -> None:
+ def save_optimizer(
+ self,
+ optimizer: Optimizer,
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ) -> None:
"""
Save optimizer to checkpoint.
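
Taken together, the reformatted `Booster` methods form a small facade over the plugin. A hedged end-to-end sketch (the model, data, and checkpoint path are placeholders, and a distributed environment launched beforehand, e.g. via `colossalai.launch`, is assumed for `TorchDDPPlugin`):

import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

booster = Booster(plugin=TorchDDPPlugin())  # the plugin controls device and precision
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x, y = torch.randn(8, 16).cuda(), torch.randn(8, 4).cuda()
loss = criterion(model(x), y)
booster.backward(loss, optimizer)  # delegates to optimizer.backward(loss)
optimizer.step()
booster.save_model(model, "ckpt", shard=False)
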
diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py
index 0df9d84159f9..68c6221ec809 100644
--- a/colossalai/booster/mixed_precision/__init__.py
+++ b/colossalai/booster/mixed_precision/__init__.py
@@ -6,16 +6,22 @@
from .mixed_precision_base import MixedPrecision
__all__ = [
- 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision',
- 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision'
+ "MixedPrecision",
+ "mixed_precision_factory",
+ "FP16_Apex_MixedPrecision",
+ "FP16_Torch_MixedPrecision",
+ "FP32_MixedPrecision",
+ "BF16_MixedPrecision",
+ "FP8_MixedPrecision",
+ "FP16NaiveMixedPrecision",
]
_mixed_precision_mapping = {
- 'fp16': FP16TorchMixedPrecision,
- 'fp16_apex': FP16ApexMixedPrecision,
- 'fp16_naive': FP16NaiveMixedPrecision,
- 'bf16': BF16MixedPrecision,
- 'fp8': FP8MixedPrecision
+ "fp16": FP16TorchMixedPrecision,
+ "fp16_apex": FP16ApexMixedPrecision,
+ "fp16_naive": FP16NaiveMixedPrecision,
+ "bf16": BF16MixedPrecision,
+ "fp8": FP8MixedPrecision,
}
@@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
return _mixed_precision_mapping[mixed_precision_type]()
else:
raise ValueError(
- f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
+ f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
)
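
A short usage sketch of the factory above (return types follow `_mixed_precision_mapping`):

from colossalai.booster.mixed_precision import mixed_precision_factory

mp = mixed_precision_factory("fp16")  # -> FP16TorchMixedPrecision()
# Unknown names raise the ValueError formatted above, e.g.:
# mixed_precision_factory("fp4")  # ValueError: Mixed precision type fp4 is not supported, ...
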
diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py
index e184271e932a..2fa7b54cdd30 100644
--- a/colossalai/booster/mixed_precision/fp16_apex.py
+++ b/colossalai/booster/mixed_precision/fp16_apex.py
@@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision):
max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
"""
- def __init__(self,
- opt_level: Optional[str] = "O1",
- cast_model_type: torch.dtype = None,
- patch_torch_functions: bool = None,
- keep_batchnorm_fp32: Union[bool, str] = None,
- master_weights: bool = None,
- loss_scale: Union[float, str] = None,
- cast_model_outputs: Any = None,
- num_losses: Optional[int] = 1,
- verbosity: int = 1,
- min_loss_scale: float = None,
- max_loss_scale: float = 2.**24) -> None:
+ def __init__(
+ self,
+ opt_level: Optional[str] = "O1",
+ cast_model_type: torch.dtype = None,
+ patch_torch_functions: bool = None,
+ keep_batchnorm_fp32: Union[bool, str] = None,
+ master_weights: bool = None,
+ loss_scale: Union[float, str] = None,
+ cast_model_outputs: Any = None,
+ num_losses: Optional[int] = 1,
+ verbosity: int = 1,
+ min_loss_scale: float = None,
+ max_loss_scale: float = 2.0**24,
+ ) -> None:
pass
diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py
index 5d0d815257f3..e5624a9d7477 100644
--- a/colossalai/booster/mixed_precision/fp16_naive.py
+++ b/colossalai/booster/mixed_precision/fp16_naive.py
@@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision):
verbose(bool): if set to `True`, will print debug info.
"""
- def __init__(self,
- log_num_zeros_in_grad: bool,
- initial_scale: int,
- growth_factor: int,
- backoff_factor: float,
- hysteresis: int,
- max_scale: int,
- verbose: bool = None) -> None:
+ def __init__(
+ self,
+ log_num_zeros_in_grad: bool,
+ initial_scale: int,
+ growth_factor: int,
+ backoff_factor: float,
+ hysteresis: int,
+ max_scale: int,
+ verbose: bool = None,
+ ) -> None:
pass
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 26fd92bd50b8..7dce6e6da33e 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -9,7 +9,7 @@
from .mixed_precision_base import MixedPrecision
-__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
+__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper):
@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
calls that may cause the scale to increase. Default: 2000.
"""
- def __init__(self,
- optim: Optimizer,
- init_scale: float = 2.**16,
- growth_factor: float = 2.0,
- backoff_factor: float = 0.5,
- growth_interval: int = 2000) -> None:
+ def __init__(
+ self,
+ optim: Optimizer,
+ init_scale: float = 2.0**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ ) -> None:
super().__init__(optim)
- self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval)
+ self.scaler = torch.cuda.amp.GradScaler(
+ init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ )
def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
@@ -60,12 +64,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
self.unscale_grad()
super().clip_grad_by_value(clip_value, *args, **kwargs)
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2.0,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> None:
+ def clip_grad_by_norm(
+ self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs,
+ ) -> None:
self.unscale_grad()
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
@@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision):
calls that may cause the scale to increase. Default: 2000.
"""
- def __init__(self,
- init_scale: float = 2.**16,
- growth_factor: float = 2.0,
- backoff_factor: float = 0.5,
- growth_interval: int = 2000) -> None:
+ def __init__(
+ self,
+ init_scale: float = 2.0**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ ) -> None:
super().__init__()
- self.torch_amp_kwargs = dict(init_scale=init_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval)
-
- def configure(self,
- model: nn.Module,
- optimizer: Optional[Optimizer] = None,
- criterion: Optional[Callable] = None,
- ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+ self.torch_amp_kwargs = dict(
+ init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ )
+
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
if optimizer is not None:
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
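
`TorchAMPOptimizer` is a thin wrapper over `torch.cuda.amp.GradScaler`; its `backward` and `clip_grad_by_norm` reproduce the scaler's canonical step order. A plain-PyTorch sketch of that order (not ColossalAI API; `model`, `criterion`, and the clipping threshold are placeholders):

import torch

scaler = torch.cuda.amp.GradScaler(
    init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000
)

def train_step(model, optimizer, criterion, x, y):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(optimizer)     # unscale before any gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)         # skips the step if inf/NaN grads were found
    scaler.update()                # adjust the scale for the next iteration
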
diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index f48bf38bd724..62f3708fc629 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -4,11 +4,12 @@
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
-__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']
+__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
import torch
from packaging import version
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
from .torch_fsdp_plugin import TorchFSDPPlugin
- __all__.append('TorchFSDPPlugin')
+
+ __all__.append("TorchFSDPPlugin")
diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
index d5da5938bfd9..d2dd00453e32 100644
--- a/colossalai/booster/plugin/dp_plugin_base.py
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -10,25 +10,19 @@
class DPPluginBase(Plugin):
- """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.
- """
+ """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation."""
def __init__(self) -> None:
super().__init__()
- assert dist.is_initialized(
- ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
+ assert (
+ dist.is_initialized()
+ ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment"
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
- def prepare_dataloader(self,
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- **kwargs):
+ def prepare_dataloader(
+ self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ ):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
@@ -60,11 +54,13 @@ def seed_worker(worker_id):
torch.manual_seed(worker_seed)
random.seed(worker_seed)
- return DataLoader(dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs)
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
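
`prepare_dataloader` pairs a `DistributedSampler` with a seeded `worker_init_fn` so shuffling stays reproducible per rank and per worker. A self-contained sketch of the same recipe in plain PyTorch (the `torch.initial_seed()` seeding scheme is the standard recipe, an assumption here):

import random
import torch
from torch.utils.data import DataLoader, DistributedSampler

def make_loader(dataset, batch_size, rank, world_size, shuffle=True):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)

    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32  # derive a per-worker seed
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      worker_init_fn=seed_worker, drop_last=False)
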
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index de03ba27bfda..83a00d4ee229 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -27,14 +27,13 @@
from .dp_plugin_base import DPPluginBase
-__all__ = ['GeminiPlugin']
+__all__ = ["GeminiPlugin"]
-SUPPORTED_PRECISION = ['fp16', 'bf16']
-PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
+SUPPORTED_PRECISION = ["fp16", "bf16"]
+PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
class GeminiCheckpointIO(GeneralCheckpointIO):
-
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
@@ -74,13 +73,15 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
super().load_unsharded_optimizer(optimizer, checkpoint)
- def save_sharded_model(self,
- model: GeminiDDP,
- checkpoint_path: str,
- gather_dtensor: bool = False,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
+ def save_sharded_model(
+ self,
+ model: GeminiDDP,
+ checkpoint_path: str,
+ gather_dtensor: bool = False,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
@@ -97,34 +98,37 @@ def save_sharded_model(self,
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint_path,
- index_file=index_file,
- base_filename=weights_name,
- is_master=is_master,
- use_safetensors=use_safetensors)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=is_master,
+ use_safetensors=use_safetensors,
+ )
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path)
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
-
- def load_sharded_model(self,
- model: GeminiDDP,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False):
+ logging.info(
+ f"The model is split into checkpoint shards. "
+                f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_sharded_model(
+ self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
+ ):
"""
        Load sharded model from multiple files.
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
- def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
- size_per_shard: int):
+ def save_sharded_optimizer(
+ self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
@@ -153,20 +157,24 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=is_master,
- use_safetensors=False)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=is_master,
+ use_safetensors=False,
+ )
# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+                f"The optimizer is going to be split into checkpoint shards. "
+                f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
"""
@@ -185,8 +193,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa
# Load param_groups.
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+                f"Invalid index file path {checkpoint_index_file} for an optimizer. "
+                "Lacking param group file under current directory."
+ )
saved_param_groups = torch.load(param_group_path)
optimizer.load_param_groups(saved_param_groups)
@@ -274,11 +284,11 @@ def __init__(
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
- shard_param_frac: float = 1.0, # only for static placement
- offload_optim_frac: float = 0.0, # only for static placement
- offload_param_frac: float = 0.0, # only for static placement
- warmup_non_model_data_ratio: float = 0.8, # only for auto placement
- steady_cuda_cap_ratio: float = 0.9, # only for auto placement
+ shard_param_frac: float = 1.0, # only for static placement
+ offload_optim_frac: float = 0.0, # only for static placement
+ offload_param_frac: float = 0.0, # only for static placement
+ warmup_non_model_data_ratio: float = 0.8, # only for auto placement
+ steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
@@ -300,7 +310,7 @@ def __init__(
verbose: bool = False,
) -> None:
super().__init__()
- assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
+ assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
@@ -319,16 +329,20 @@ def __init__(
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
)
- self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
- self.optim_kwargs = dict(initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- max_norm=max_norm,
- norm_type=norm_type)
+ self.zero_optim_config = dict(
+ gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+ )
+ self.optim_kwargs = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ )
self.verbose = verbose
def support_no_sync(self) -> bool:
@@ -344,7 +358,7 @@ def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -354,7 +368,6 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -368,13 +381,10 @@ def configure(
# wrap the model with Gemini
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
- if optimizer is not None and \
- not isinstance(optimizer, OptimizerWrapper):
- optimizer = GeminiOptimizer(optimizer,
- model.unwrap(),
- **self.zero_optim_config,
- **self.optim_kwargs,
- verbose=self.verbose)
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ optimizer = GeminiOptimizer(
+ optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+ )
return model, optimizer, criterion, dataloader, lr_scheduler
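
A hedged sketch of wiring `GeminiPlugin` into `Booster` (argument values are illustrative, a launched distributed environment is assumed, and `HybridAdam` is ColossalAI's fused optimizer commonly paired with Gemini):

import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

plugin = GeminiPlugin(placement_policy="static", precision="fp16")
booster = Booster(plugin=plugin)  # Gemini controls device and precision

model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 8))
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
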
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index d15245523226..c1693fa8d3a1 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -37,10 +37,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class HybridParallelModule(ModelWrapper):
-
- def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
- ddp_config: dict, custom_policy: Policy) -> None:
-
+ def __init__(
+ self,
+ module: Module,
+ precision: str,
+ shard_config: ShardConfig,
+ dp_group: ProcessGroup,
+ use_ddp: bool,
+ ddp_config: dict,
+ custom_policy: Policy,
+ ) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
@@ -54,13 +60,14 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.shared_param_process_groups.append(
- self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
+ self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
+ )
# setting mixed_precision
self.mixed_precision = None
- if precision == 'fp16':
+ if precision == "fp16":
self.mixed_precision = torch.float16
- elif precision == 'bf16':
+ elif precision == "bf16":
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
@@ -123,22 +130,21 @@ def get_param_info(optim: Optimizer):
if optim is None:
return {}
- param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
+ param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
-        packed_group = {k: v for k, v in group.items() if k != 'params'}
-        packed_group['params'] = []
-
-        for param_id, param in enumerate(group['params'], start_index):
+        packed_group = {k: v for k, v in group.items() if k != "params"}
+        packed_group["params"] = []
+        for param_id, param in enumerate(group["params"], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
- packed_group['params'].append(param_id)
- param_info['param2id'][id(param)] = param_id
- param_info['id2param'][param_id] = id(param)
- param_info['param2shape'][id(param)] = original_shape
+ packed_group["params"].append(param_id)
+ param_info["param2id"][id(param)] = param_id
+ param_info["id2param"][param_id] = id(param)
+ param_info["param2shape"][id(param)] = original_shape
- param_info['param_groups'].append(packed_group)
- start_index += len(group['params'])
+ param_info["param_groups"].append(packed_group)
+ start_index += len(group["params"])
return param_info
@@ -147,13 +153,12 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
- params = [p for p in group['params'] if p in model_params]
- new_param_groups.append({**group, 'params': params})
- optim.__setstate__({'param_groups': new_param_groups})
+ params = [p for p in group["params"] if p in model_params]
+ new_param_groups.append({**group, "params": params})
+ optim.__setstate__({"param_groups": new_param_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper):
-
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
self.param_info = param_info
if use_pipeline:
@@ -162,60 +167,87 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_in
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
-
- def __init__(self,
- optim: Optimizer,
- model: Module,
- use_pipeline: bool,
- param_info: OrderedDict,
- precision: str = 'fp16',
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0):
+ def __init__(
+ self,
+ optim: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ precision: str = "fp16",
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ ):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
- super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
- hysteresis, max_scale, max_norm)
+ super().__init__(
+ optim,
+ precision,
+ initial_scale,
+ min_scale,
+ growth_factor,
+ backoff_factor,
+ growth_interval,
+ hysteresis,
+ max_scale,
+ max_norm,
+ )
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
-
def __init__(
- self,
- optimizer: Optimizer,
- model: Module,
- use_pipeline: bool,
- param_info: OrderedDict,
- initial_scale: int = 2**16, # grad scaler config
- min_scale: int = 1,
- growth_factor: float = 2.,
- backoff_factor: float = .5,
- growth_interval: int = 2000,
- hysteresis: int = 2,
- max_scale: int = 2**24,
- clip_grad_norm: float = 0.0, # grad clipping
- verbose: bool = False,
- reduce_bucket_size: int = 1024 * 1024, # communication
- communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True,
- partition_grad: bool = False, # stage 2 flag
- cpu_offload: bool = False, # cpu offload
- dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
- tp_process_group: Optional[ProcessGroup] = None, # if using tp
- forced_dtype: Optional[torch.dtype] = None):
+ self,
+ optimizer: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ initial_scale: int = 2**16, # grad scaler config
+ min_scale: int = 1,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ hysteresis: int = 2,
+ max_scale: int = 2**24,
+ clip_grad_norm: float = 0.0, # grad clipping
+ verbose: bool = False,
+ reduce_bucket_size: int = 1024 * 1024, # communication
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ partition_grad: bool = False, # stage 2 flag
+ cpu_offload: bool = False, # cpu offload
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ tp_process_group: Optional[ProcessGroup] = None, # if using tp
+ forced_dtype: Optional[torch.dtype] = None,
+ ):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
- super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
- hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
- overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
- forced_dtype)
+ super().__init__(
+ optimizer,
+ initial_scale,
+ min_scale,
+ growth_factor,
+ backoff_factor,
+ growth_interval,
+ hysteresis,
+ max_scale,
+ clip_grad_norm,
+ verbose,
+ reduce_bucket_size,
+ communication_dtype,
+ overlap_communication,
+ partition_grad,
+ cpu_offload,
+ dp_process_group,
+ tp_process_group,
+ forced_dtype,
+ )
class HybridParallelPlugin(PipelinePluginBase):
@@ -276,46 +308,47 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
"""
- def __init__(self,
- tp_size: int,
- pp_size: int,
- precision: str = 'fp16',
- zero_stage: int = 0,
- enable_all_optimization: bool = False,
- enable_fused_normalization: bool = False,
- enable_flash_attention: bool = False,
- enable_jit_fused: bool = False,
- enable_sequence_parallelism: bool = False,
- enable_sequence_overlap: bool = False,
- num_microbatches: Optional[int] = None,
- microbatch_size: Optional[int] = None,
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0,
- broadcast_buffers: bool = True,
- ddp_bucket_cap_mb: int = 25,
- find_unused_parameters: bool = False,
- check_reduction: bool = False,
- gradient_as_bucket_view: bool = False,
- static_graph: bool = False,
- zero_bucket_size_in_m: int = 12,
- cpu_offload: bool = False,
- communication_dtype: Optional[torch.dtype] = None,
- overlap_communication: bool = True,
- custom_policy: Policy = None) -> None:
-
+ def __init__(
+ self,
+ tp_size: int,
+ pp_size: int,
+ precision: str = "fp16",
+ zero_stage: int = 0,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_parallelism: bool = False,
+ enable_sequence_overlap: bool = False,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ broadcast_buffers: bool = True,
+ ddp_bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ zero_bucket_size_in_m: int = 12,
+ cpu_offload: bool = False,
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ custom_policy: Policy = None,
+ ) -> None:
super().__init__()
- assert dist.get_world_size() % (
- tp_size * pp_size
- ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
+ assert (
+ dist.get_world_size() % (tp_size * pp_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
- assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
+            assert tp_size > 1, "Tensor parallelism must be enabled when using sequence parallelism"
self.tp_size = tp_size
self.pp_size = pp_size
@@ -334,24 +367,28 @@ def __init__(self,
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
- assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
- assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
+ assert (
+ num_microbatches is not None or microbatch_size is not None
+ ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
+ assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
- self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
- num_microbatches=num_microbatches,
- microbatch_size=microbatch_size)
+ self.schedule = OneForwardOneBackwardSchedule(
+ self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+ )
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
- self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
- pipeline_stage_manager=self.stage_manager,
- enable_tensor_parallelism=self.tp_size > 1,
- enable_all_optimization=self.enable_all_optimization,
- enable_fused_normalization=self.enable_fused_normalization,
- enable_flash_attention=self.enable_flash_attention,
- enable_jit_fused=self.enable_jit_fused,
- enable_sequence_parallelism=enable_sequence_parallelism,
- enable_sequence_overlap=enable_sequence_overlap)
+ self.shard_config = ShardConfig(
+ tensor_parallel_process_group=self.tp_group,
+ pipeline_stage_manager=self.stage_manager,
+ enable_tensor_parallelism=self.tp_size > 1,
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism,
+ enable_sequence_overlap=enable_sequence_overlap,
+ )
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@@ -362,18 +399,22 @@ def __init__(self,
max_scale=max_scale,
)
- self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
- bucket_cap_mb=ddp_bucket_cap_mb,
- find_unused_parameters=find_unused_parameters,
- check_reduction=check_reduction,
- gradient_as_bucket_view=gradient_as_bucket_view,
- static_graph=static_graph)
+ self.ddp_config = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=ddp_bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
- self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
- communication_dtype=communication_dtype,
- overlap_communication=overlap_communication,
- cpu_offload=cpu_offload,
- partition_grad=(self.zero_stage == 2))
+ self.zero_config = dict(
+ reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(self.zero_stage == 2),
+ )
self.max_norm = max_norm
@@ -382,10 +423,10 @@ def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def supported_precisions(self) -> List[str]:
- return ['fp16', 'bf16', 'fp32']
+ return ["fp16", "bf16", "fp32"]
def control_device(self) -> bool:
return True
@@ -410,57 +451,67 @@ def configure(
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
- model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
- self.ddp_config, self.custom_policy)
+ model = HybridParallelModule(
+ model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+ )
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
- if self.precision in ['fp16', 'bf16']:
- optimizer = HybridParallelAMPOptimizer(optimizer,
- model,
- use_pipeline=self.enable_pipeline_parallelism,
- param_info=param_info,
- precision=self.precision,
- max_norm=self.max_norm,
- **self.amp_config)
- self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
- optimizer.master_to_working_map)
+ if self.precision in ["fp16", "bf16"]:
+ optimizer = HybridParallelAMPOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ precision=self.precision,
+ max_norm=self.max_norm,
+ **self.amp_config,
+ )
+ self.checkpoint_io.link_master_and_working_param(
+ optimizer.working_to_master_map, optimizer.master_to_working_map
+ )
else:
- optimizer = HybridParallelNaiveOptimizer(optimizer,
- model,
- use_pipeline=self.enable_pipeline_parallelism,
- param_info=param_info)
+ optimizer = HybridParallelNaiveOptimizer(
+ optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+ )
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
- assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
- optimizer = HybridParallelZeroOptimizer(optimizer,
- model,
- use_pipeline=self.enable_pipeline_parallelism,
- param_info=param_info,
- dp_process_group=self.dp_group,
- tp_process_group=self.tp_group,
- verbose=True,
- clip_grad_norm=self.max_norm,
- **self.zero_config,
- **self.amp_config)
- self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
- optimizer._param_store.master_to_working_param)
+ assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
+ optimizer = HybridParallelZeroOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ dp_process_group=self.dp_group,
+ tp_process_group=self.tp_group,
+ verbose=True,
+ clip_grad_norm=self.max_norm,
+ **self.zero_config,
+ **self.amp_config,
+ )
+ self.checkpoint_io.link_master_and_working_param(
+ optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
+ )
return model, optimizer, criterion, dataloader, lr_scheduler
- def execute_pipeline(self,
- data_iter: Iterator,
- model: HybridParallelModule,
- criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
- HybridParallelZeroOptimizer]] = None,
- return_loss: bool = True,
- return_outputs: bool = False) -> dict:
- assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: HybridParallelModule,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[
+ Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
+ ] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> dict:
+ assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
# return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
- outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
- return_outputs)
+ outputs = self.schedule.forward_backward_step(
+ model, data_iter, criterion, optimizer, return_loss, return_outputs
+ )
model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
@@ -468,15 +519,9 @@ def execute_pipeline(self,
model.sync_grads()
return outputs
- def prepare_dataloader(self,
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- **kwargs):
+ def prepare_dataloader(
+ self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ ):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
@@ -499,10 +544,9 @@ def prepare_dataloader(self,
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
- sampler = DistributedSampler(dataset,
- num_replicas=self.pg_mesh.size(DP_AXIS),
- rank=self.pg_mesh.coordinate(DP_AXIS),
- shuffle=shuffle)
+ sampler = DistributedSampler(
+ dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+ )
# Deterministic dataloader
def seed_worker(worker_id):
@@ -511,14 +555,16 @@ def seed_worker(worker_id):
torch.manual_seed(worker_seed)
random.seed(worker_seed)
- return DataLoader(dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs)
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
def get_checkpoint_io(self) -> CheckpointIO:
self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
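For reference, the `configure`/`execute_pipeline` path reformatted above is normally driven through the `Booster` front end rather than called directly. A minimal sketch of the call pattern, under stated assumptions (distributed state already initialized via `colossalai.launch`, and `dataset` a user-provided `torch.utils.data.Dataset`; none of this is part of the patch):

import torch.nn as nn
from torch.optim import SGD
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, precision="fp16")
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
optimizer = SGD(model.parameters(), lr=1e-3)

def criterion(outputs, inputs):  # matches Callable[[Any, Any], torch.Tensor]
    return outputs.mean()

dataloader = plugin.prepare_dataloader(dataset, batch_size=8, shuffle=True)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

# With pp_size > 1, each training step routes through the pipeline schedule.
outputs = booster.execute_pipeline(iter(dataloader), model, criterion, optimizer, return_loss=True)
optimizer.step()
optimizer.zero_grad()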
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 9adb4beec9b9..86adee7fe226 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,14 +1,12 @@
import logging
import os
-import warnings
from functools import partial
from pathlib import Path
from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
-from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -33,7 +31,7 @@
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
-__all__ = ['LowLevelZeroPlugin']
+__all__ = ["LowLevelZeroPlugin"]
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
@@ -42,17 +40,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
return x
-SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-
def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
self.dtype = None
- if precision == 'fp16':
+ if precision == "fp16":
self.dtype = torch.float16
- elif precision == 'bf16':
+ elif precision == "bf16":
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
@@ -74,7 +71,6 @@ def unwrap(self):
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
"""Save optimizer to checkpoint but only on master process.
@@ -91,12 +87,14 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
- def save_sharded_optimizer(self,
- optimizer: OptimizerWrapper,
- checkpoint: str,
- gather_dtensor: bool = False,
- prefix: str = None,
- size_per_shard: int = 1024):
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = False,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -148,9 +146,11 @@ def save_sharded_optimizer(self,
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
"""Load sharded optimizer with the given path to index file.
@@ -170,8 +170,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+ f"Invalid index file path {index_file_path} for an optimizer. \
+ Missing param group file in the current directory."
+ )
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
@@ -181,9 +183,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
- if isinstance(v, torch.Tensor) and k != 'step':
- padding_size = (self.coordinator.world_size -
- v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+ if isinstance(v, torch.Tensor) and k != "step":
+ padding_size = (
+ self.coordinator.world_size - v.numel() % self.coordinator.world_size
+ ) % self.coordinator.world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
@@ -194,33 +197,39 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
sharded_optimizer_loading_epilogue(optimizer)
- def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
- use_safetensors: bool):
+ def save_unsharded_model(
+ self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
+ ):
assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint_path: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_path: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
assert isinstance(model, LowLevelZeroModel)
- super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
- use_safetensors)
+ super().save_sharded_model(
+ model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+ )
def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
model.update_master_params()
- def load_sharded_model(self,
- model: LowLevelZeroModel,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False,
- load_sub_module: bool = True):
+ def load_sharded_model(
+ self,
+ model: LowLevelZeroModel,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
@@ -264,7 +273,7 @@ class LowLevelZeroPlugin(DPPluginBase):
def __init__(
self,
stage: int = 1,
- precision: str = 'fp16',
+ precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
@@ -281,9 +290,9 @@ def __init__(
verbose: bool = False,
) -> None:
super().__init__()
- assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
- assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
- assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
+ assert stage in (1, 2), "LowLevelZeroPlugin only supports stage 1/2 training"
+ assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
+ assert norm_type == 2.0, "LowLevelZeroPlugin only supports norm_type=2.0 now"
self.stage = stage
self.precision = precision
self.zero_optim_kwargs = dict(
@@ -319,7 +328,7 @@ def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -329,15 +338,13 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
- if optimizer is not None and \
- not isinstance(optimizer, OptimizerWrapper):
- optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
- **self.zero_optim_kwargs,
- verbose=self.verbose)
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
+ optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+ )
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
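One detail worth making explicit in the `load_sharded_optimizer` hunk above is the padding arithmetic: every flattened optimizer state is padded so its length divides evenly across ranks before being chunked. A standalone worked example in plain torch (nothing here is part of the patch):

import torch
import torch.nn.functional as F

world_size = 4
v = torch.arange(10, dtype=torch.float32)  # a state tensor with numel() == 10

# Same formula as above: pad the flattened state up to a multiple of world_size.
padding_size = (world_size - v.numel() % world_size) % world_size  # (4 - 10 % 4) % 4 == 2
v = v.flatten()
if padding_size > 0:
    v = F.pad(v, [0, padding_size])
assert v.numel() == 12  # 10 real elements plus 2 zeros of padding

# Each rank then keeps one equal chunk of the padded, flattened state.
shards = v.split(v.numel() // world_size)  # four shards of three elements each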
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index fb21e57f41f7..4e570cbe8abc 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
@@ -9,11 +9,10 @@
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
-__all__ = ['Plugin']
+__all__ = ["Plugin"]
class Plugin(ABC):
-
@abstractmethod
def supported_devices(self) -> List[str]:
pass
@@ -51,33 +50,31 @@ def control_checkpoint_io(self) -> bool:
"""
Whether the plugin controls the checkpoint io
"""
- pass
@abstractmethod
def get_checkpoint_io(self) -> CheckpointIO:
"""
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
"""
- pass
@abstractmethod
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
- pass
@abstractmethod
- def prepare_dataloader(self,
- dataset: Dataset,
- batch_size: int,
- shuffle: bool = False,
- seed: int = 1024,
- drop_last: bool = False,
- pin_memory: bool = False,
- num_workers: int = 0,
- **kwargs):
+ def prepare_dataloader(
+ self,
+ dataset: Dataset,
+ batch_size: int,
+ shuffle: bool = False,
+ seed: int = 1024,
+ drop_last: bool = False,
+ pin_memory: bool = False,
+ num_workers: int = 0,
+ **kwargs,
+ ):
"""Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader`
"""
- pass
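Removing the trailing `pass` statements above is safe because a docstring is itself a valid function body; the abstract methods stay syntactically complete without them. A quick standalone illustration:

from abc import ABC, abstractmethod

class ExamplePlugin(ABC):
    @abstractmethod
    def get_checkpoint_io(self):
        """The docstring alone is a valid body, so no trailing `pass` is needed."""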
diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py
index f52844db082f..3d91eb95b409 100644
--- a/colossalai/booster/plugin/pp_plugin_base.py
+++ b/colossalai/booster/plugin/pp_plugin_base.py
@@ -9,13 +9,14 @@
class PipelinePluginBase(Plugin):
-
@abstractmethod
- def execute_pipeline(self,
- data_iter: Iterator,
- model: ModelWrapper,
- criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Optional[OptimizerWrapper] = None,
- return_loss: bool = True,
- return_outputs: bool = False) -> dict:
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: ModelWrapper,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> dict:
pass
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index f3f779c88e42..30d34e7dd5e5 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -12,11 +12,10 @@
from .dp_plugin_base import DPPluginBase
-__all__ = ['TorchDDPPlugin']
+__all__ = ["TorchDDPPlugin"]
class TorchDDPCheckpointIO(GeneralCheckpointIO):
-
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
@@ -49,25 +48,29 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint_path: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_path: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
- def save_sharded_optimizer(self,
- optimizer: Optimizer,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024):
+ def save_sharded_optimizer(
+ self,
+ optimizer: Optimizer,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save optimizer to checkpoint but only on master process.
"""
@@ -76,7 +79,6 @@ def save_sharded_optimizer(self,
class TorchDDPModel(ModelWrapper):
-
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
@@ -109,20 +111,24 @@ class TorchDDPPlugin(DPPluginBase):
static_graph (bool, optional): Whether to use static graph. Defaults to False.
"""
- def __init__(self,
- broadcast_buffers: bool = True,
- bucket_cap_mb: int = 25,
- find_unused_parameters: bool = False,
- check_reduction: bool = False,
- gradient_as_bucket_view: bool = False,
- static_graph: bool = False) -> None:
+ def __init__(
+ self,
+ broadcast_buffers: bool = True,
+ bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ ) -> None:
super().__init__()
- self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
- bucket_cap_mb=bucket_cap_mb,
- find_unused_parameters=find_unused_parameters,
- check_reduction=check_reduction,
- gradient_as_bucket_view=gradient_as_bucket_view,
- static_graph=static_graph)
+ self.ddp_kwargs = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
def support_no_sync(self) -> bool:
return True
@@ -131,13 +137,13 @@ def control_precision(self) -> bool:
return False
def supported_precisions(self) -> List[str]:
- return ['fp16', 'fp16_apex', 'bf16', 'fp8']
+ return ["fp16", "fp16_apex", "bf16", "fp8"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -156,8 +162,7 @@ def configure(
# wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs)
- if optimizer is not None and \
- not isinstance(optimizer, OptimizerWrapper):
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
@@ -169,5 +174,5 @@ def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
- assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+ assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync()
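The `no_sync` hook above is what gradient accumulation builds on for DDP-boosted models. A hedged sketch of that pattern (assuming `booster`, `model`, `optimizer`, `criterion`, and `dataloader` come from `booster.boost(...)` with this plugin, and that `Booster.no_sync(model)` forwards here; the two-step accumulation is illustrative):

import contextlib

accum_steps = 2
for step, batch in enumerate(dataloader):
    is_sync_step = (step + 1) % accum_steps == 0
    # Skip DDP's gradient all-reduce on non-boundary steps to save communication.
    ctx = contextlib.nullcontext() if is_sync_step else booster.no_sync(model)
    with ctx:
        loss = criterion(model(batch)) / accum_steps
        booster.backward(loss, optimizer)
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()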
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index fb7b5baadd0c..d12b784b4fc1 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,13 +1,13 @@
import warnings
from pathlib import Path
-from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
@@ -31,11 +31,10 @@
from .dp_plugin_base import DPPluginBase
-__all__ = ['TorchFSDPPlugin']
+__all__ = ["TorchFSDPPlugin"]
class TorchFSDPCheckpointIO(GeneralCheckpointIO):
-
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
@@ -69,26 +68,36 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
- def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
- size_per_shard: int, use_safetensors: bool):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ prefix: Optional[str],
+ size_per_shard: int,
+ use_safetensors: bool,
+ ):
"""
Save model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
- def load_sharded_model(self,
- model: nn.Module,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False,
- load_sub_module: bool = True):
+ def load_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
"""
Load model from checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
- def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
- size_per_shard: int):
+ def save_sharded_optimizer(
+ self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
"""
Save optimizer to checkpoint but only on master process.
"""
@@ -109,7 +118,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
class TorchFSDPModel(ModelWrapper):
-
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)
@@ -119,7 +127,6 @@ def unwrap(self):
class FSDPOptimizerWrapper(OptimizerWrapper):
-
def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)
@@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
- if version.parse(torch.__version__) >= version.parse('1.12.0'):
+ if version.parse(torch.__version__) >= version.parse("1.12.0"):
def __init__(
self,
@@ -162,15 +169,18 @@ def __init__(
sync_module_states: bool = False,
):
super().__init__()
- self.fsdp_kwargs = dict(process_group=process_group,
- sharding_strategy=sharding_strategy,
- cpu_offload=cpu_offload,
- auto_wrap_policy=auto_wrap_policy,
- backward_prefetch=backward_prefetch,
- mixed_precision=mixed_precision,
- ignored_modules=ignored_modules,
- param_init_fn=param_init_fn,
- sync_module_states=sync_module_states)
+ self.fsdp_kwargs = dict(
+ process_group=process_group,
+ sharding_strategy=sharding_strategy,
+ cpu_offload=cpu_offload,
+ auto_wrap_policy=auto_wrap_policy,
+ backward_prefetch=backward_prefetch,
+ mixed_precision=mixed_precision,
+ ignored_modules=ignored_modules,
+ param_init_fn=param_init_fn,
+ sync_module_states=sync_module_states,
+ )
+
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -184,13 +194,13 @@ def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
- return ['fp16', 'bf16']
+ return ["fp16", "bf16"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
+ return ["cuda"]
def configure(
self,
@@ -200,14 +210,13 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
- 'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+ "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
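The `optimizer.__init__(...)` call above is needed because FSDP replaces the module's parameters with flattened, sharded views when wrapping, so an optimizer built on the unwrapped module would hold stale references. A standalone sketch of the same rebuild pattern (assumes `torch.distributed` is already initialized, which FSDP requires):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

fsdp_model = FSDP(model)  # parameters become flat, sharded views

# Rebuild the (single) param group against the wrapped parameters while keeping
# the original hyperparameters, mirroring TorchFSDPPlugin.configure above.
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)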
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index e1aa6543ef39..19b61730bded 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -3,4 +3,4 @@
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
+__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index baff24e1cb25..f8ce8f4e5210 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -11,7 +11,7 @@
from .utils import has_index_file
-__all__ = ['CheckpointIO']
+__all__ = ["CheckpointIO"]
class CheckpointIO(ABC):
@@ -61,10 +61,9 @@ class CheckpointIO(ABC):
# ======================================
# Public methods
# ======================================
- def load_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- strict: bool = True) -> Union[nn.Module, ModelWrapper]:
+ def load_model(
+ self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
+ ) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@@ -98,14 +97,16 @@ def load_model(self,
return origin_model
- def save_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- shard: bool = False,
- gather_dtensor: bool = True,
- prefix: str = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False):
+ def save_model(
+ self,
+ model: Union[nn.Module, ModelWrapper],
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save model to checkpoint.
@@ -157,7 +158,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No
if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error
- raise ValueError(f'Cannot find index file in {checkpoint}')
+ raise ValueError(f"Cannot find index file in {checkpoint}")
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
@@ -165,13 +166,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
- def save_optimizer(self,
- optimizer: Optimizer,
- checkpoint: str,
- shard: bool = False,
- gather_dtensor=True,
- prefix: str = None,
- size_per_shard: int = 1024):
+ def save_optimizer(
+ self,
+ optimizer: Optimizer,
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor=True,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -207,7 +210,6 @@ def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: boo
strict (bool): whether to strictly enforce that the param names in
the checkpoint match the keys returned by this module's state_dict().
"""
- pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
@@ -220,11 +222,17 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
strict (bool): whether to strictly enforce that the param names in
the checkpoint match the keys returned by this module's state_dict().
"""
- pass
@abstractmethod
- def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
- size_per_shard: int, use_safetensors: bool):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ prefix: Optional[str],
+ size_per_shard: int,
+ use_safetensors: bool,
+ ):
"""
Save model to sharded checkpoint.
@@ -236,7 +244,6 @@ def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor:
size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors.
"""
- pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
@@ -249,7 +256,6 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors.
"""
- pass
# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -265,7 +271,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
"""
- pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
@@ -276,11 +281,11 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
- pass
@abstractmethod
- def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
- size_per_shard: int):
+ def save_sharded_optimizer(
+ self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
"""
Save optimizer to sharded checkpoint.
@@ -291,7 +296,6 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
- pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
@@ -303,7 +307,6 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gathe
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
"""
- pass
# ============================================
# methods for loading and saving lr scheduler
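Taken together, the public entry points reformatted above dispatch on the `shard` flag when saving and on the presence of an index file when loading. A usage sketch with the concrete `GeneralCheckpointIO` (paths are placeholders and `model`/`optimizer` are assumed to exist):

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()

# Sharded save: writes weight shards plus an index file mapping params to shards.
ckpt_io.save_model(model, "./ckpt_model", shard=True, size_per_shard=1024)
ckpt_io.save_optimizer(optimizer, "./ckpt_optim", shard=True)

# Loading accepts a single-file checkpoint or a directory containing an index file.
model = ckpt_io.load_model(model, "./ckpt_model")
ckpt_io.load_optimizer(optimizer, "./ckpt_optim")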
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index faaf1d22722a..b0e593e90d8c 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -3,9 +3,8 @@
import os
from functools import reduce
from pathlib import Path
-from typing import Iterator, Optional, OrderedDict, Tuple
+from typing import Optional
-import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
@@ -16,7 +15,6 @@
from .utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
- get_shard_filename,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
@@ -33,7 +31,7 @@
unwrap_optimizer,
)
-__all__ = ['GeneralCheckpointIO']
+__all__ = ["GeneralCheckpointIO"]
class GeneralCheckpointIO(CheckpointIO):
@@ -70,8 +68,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+ f"Invalid index file path {index_file_path} for an optimizer. \
+ Missing param group file in the current directory."
+ )
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
@@ -123,19 +123,23 @@ def save_sharded_optimizer(
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
- total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=True,
- use_safetensors=False)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=sharded_state,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=True,
+ use_safetensors=False,
+ )
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
@@ -150,13 +154,15 @@ def save_unsharded_optimizer(
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint_path: str,
- gather_dtensor: bool = False,
- prefix: Optional[str] = None,
- max_shard_size: int = 1024,
- use_safetensors: bool = False):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_path: str,
+ gather_dtensor: bool = False,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save a sharded model checkpoint, splitting the model across multiple files,
as is supported for Huggingface models.
@@ -175,26 +181,32 @@ def save_sharded_model(self,
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint_path,
- index_file=index_file,
- base_filename=weights_name,
- is_master=True,
- use_safetensors=use_safetensors)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=True,
+ use_safetensors=use_safetensors,
+ )
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
- logging.info(f"The model is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
-
- def load_sharded_model(self,
- model: nn.Module,
- checkpoint_index_file: Path,
- strict: bool = False,
- use_safetensors: bool = False,
- load_sub_module: bool = True):
+ logging.info(
+ f"The model is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
"""
Load a sharded model from multiple checkpoint files.
"""
@@ -219,7 +231,11 @@ def load_sharded_model(self,
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
- error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
- '"{}"'.format(k) for k in missing_keys))
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
- self.__class__.__name__, "\n\t".join(error_msgs)))
+ error_msgs = "Missing key(s) in state_dict: {}. ".format(
+ ", ".join('"{}"'.format(k) for k in missing_keys)
+ )
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
+ self.__class__.__name__, "\n\t".join(error_msgs)
+ )
+ )
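For reference, the index files written via `index_file.write_index_file(save_index_file)` follow the Hugging Face sharded-checkpoint layout. The shape below is inferred from the `total_size` metadata and `weight_map` fields used in this diff; parameter and file names are illustrative:

# The Python dict that gets serialized to the .index.json (names made up):
index = {
    "metadata": {"total_size": 13476839424},  # total bytes across all shards
    "weight_map": {
        "embed_tokens.weight": "pytorch_model-00001-of-00003.bin",
        "layers.0.attn.q_proj.weight": "pytorch_model-00001-of-00003.bin",
        "lm_head.weight": "pytorch_model-00003-of-00003.bin",
    },
}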
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 270fd8564754..18c59a880dd6 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,5 +1,4 @@
import copy
-import gc
import logging
import os
from pathlib import Path
@@ -35,9 +34,9 @@
)
try:
- from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
- _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+ _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class HybridParallelCheckpointIO(GeneralCheckpointIO):
@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
"""
- def __init__(self,
- dp_group: ProcessGroup,
- pp_group: ProcessGroup,
- tp_group: ProcessGroup,
- zero_stage: int,
- verbose: bool = True) -> None:
+ def __init__(
+ self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ verbose: bool = True,
+ ) -> None:
super().__init__()
self.dp_group = dp_group
self.pp_group = pp_group
@@ -68,17 +69,16 @@ def __init__(self,
self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
- self.use_zero = (zero_stage > 0)
+ self.use_zero = zero_stage > 0
self.verbose = verbose
self.working_to_master_map = None
self.master_to_working_map = None
self.coordinator = DistCoordinator()
@staticmethod
- def _model_sharder(model: nn.Module,
- prefix: str = '',
- keep_vars: bool = False,
- size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+ def _model_sharder(
+ model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
+ ) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks the state_dict of the model into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
@@ -103,8 +103,10 @@ def _model_sharder(model: nn.Module,
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
- if getattr(model.__class__, "get_extra_state",
- torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
@@ -114,20 +116,20 @@ def _model_sharder(model: nn.Module,
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
@staticmethod
- def _optimizer_sharder(optimizer: OptimizerWrapper,
- use_zero: bool,
- dp_group: ProcessGroup,
- tp_group: ProcessGroup,
- master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
- size_per_shard: int = 1024):
-
+ def _optimizer_sharder(
+ optimizer: OptimizerWrapper,
+ use_zero: bool,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
+ size_per_shard: int = 1024,
+ ):
# An internal method that breaks the state_dict of the optimizer into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
for param, state in optimizer.optim.state.items():
-
if param is None:
continue
@@ -136,15 +138,17 @@ def _optimizer_sharder(optimizer: OptimizerWrapper,
else:
working_param = param
- param_id = param_info['param2id'][id(working_param)]
- original_shape = param_info['param2shape'][id(working_param)]
- state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
- working_param,
- original_shape=original_shape,
- dp_group=dp_group,
- tp_group=tp_group,
- use_zero=use_zero,
- inplace=False)
+ param_id = param_info["param2id"][id(working_param)]
+ original_shape = param_info["param2shape"][id(working_param)]
+ state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=dp_group,
+ tp_group=tp_group,
+ use_zero=use_zero,
+ inplace=False,
+ )
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
@@ -153,13 +157,15 @@ def _optimizer_sharder(optimizer: OptimizerWrapper,
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
- def save_sharded_model(self,
- model: nn.Module,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False) -> None:
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -194,24 +200,28 @@ def save_sharded_model(self,
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
- control_saving = (self.tp_rank == 0)
+ control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=weights_name,
- is_master=control_saving,
- use_safetensors=use_safetensors)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ )
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -228,15 +238,19 @@ def save_sharded_model(self,
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=weights_name,
- is_master=control_saving,
- use_safetensors=use_safetensors,
- use_pp_format=True)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ use_pp_format=True,
+ )
if control_saving:
- assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ assert (
+ self.dp_rank == 0 and self.tp_rank == 0
+ ), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
@@ -259,9 +273,11 @@ def save_sharded_model(self,
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose:
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {final_index_file_path}.")
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
@@ -305,11 +321,9 @@ def _load(name: str):
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
- load_state_dict_into_model(model,
- state_dict,
- missing_keys=missing_keys,
- strict=strict,
- load_sub_module=True)
+ load_state_dict_into_model(
+ model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
+ )
loaded_file.add(filename)
# Load parameters.
@@ -319,15 +333,17 @@ def _load(name: str):
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
- non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
+ non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
- if getattr(model.__class__, "get_extra_state",
- torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
@@ -352,12 +368,14 @@ def _load(name: str):
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
- def save_sharded_optimizer(self,
- optimizer: OptimizerWrapper,
- checkpoint: str,
- gather_dtensor: bool = True,
- prefix: Optional[str] = None,
- size_per_shard: int = 1024):
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
@@ -393,18 +411,21 @@ def save_sharded_optimizer(self,
dp_group=self.dp_group,
tp_group=self.tp_group,
master_to_working_map=self.master_to_working_map,
- size_per_shard=size_per_shard)
+ size_per_shard=size_per_shard,
+ )
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
- control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
+ control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=control_saving)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ )
if control_saving:
# Store param groups.
@@ -415,9 +436,11 @@ def save_sharded_optimizer(self,
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose:
- logging.info(f"The optimizer is going to be split to checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -433,15 +456,19 @@ def save_sharded_optimizer(self,
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
- total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
- checkpoint=checkpoint,
- index_file=index_file,
- base_filename=states_name,
- is_master=control_saving,
- use_pp_format=True)
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True,
+ )
if control_saving:
- assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ assert (
+ self.dp_rank == 0 and self.tp_rank == 0
+ ), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
@@ -451,7 +478,6 @@ def save_sharded_optimizer(self,
# The global master rank integrates the index files and clean the folder.
if self.pp_rank == 0:
-
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
@@ -470,9 +496,11 @@ def save_sharded_optimizer(self,
rmtree(tmp_index_file_folder)
if self.verbose:
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {final_index_file_path}.")
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
@@ -484,20 +512,21 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f
prefix (str): Not used.
"""
- def _get_param_id_from_optimizer_param(param: torch.Tensor,
- master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
- return optimizer.param_info['param2id'][id(working_param)]
+ return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through the param2id mapping saved earlier in optimizer.param_info.
id_map = {}
for pg in optimizer.optim.param_groups:
- for param in pg['params']:
+ for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
id_map[param_id] = param
@@ -505,28 +534,30 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor,
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
- weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
- raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
- Lacking param group file under current directory.')
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Missing param group file in the current directory."
+ )
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
- new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change.
+ new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
- optimizer.optim.__dict__.update({'param_groups': updated_groups})
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that a file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
- for param in pg['params']:
+ for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
@@ -550,12 +581,10 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor,
working_param = self.master_to_working_map[id(param)]
else:
working_param = param
- original_shape = optimizer.param_info['param2shape'][id(working_param)]
- sharded_state = self.shard_from_complete_optimizer_state(state,
- current_shape=working_param.shape,
- original_shape=original_shape,
- device=device,
- inplace=True)
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(
+ state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+ )
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
@@ -585,8 +614,11 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
- def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
- master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
+ def link_master_and_working_param(
+ self,
+ working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
+ master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
+ ):
"""
Create mappings between working params (for forward/backward) and master params (for optimizer update) from the passed-in mappings.
This mapping can only be created when mixed precision is used.
@@ -604,7 +636,8 @@ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, t
self.working_to_master_map[k] = v
else:
raise ValueError(
- f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
+ f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
+ )
self.master_to_working_map = dict()
for k, v in master_to_working_map.items():
@@ -614,12 +647,19 @@ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, t
self.master_to_working_map[k] = v
else:
raise ValueError(
- f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
+ f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
+ )
@staticmethod
- def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
- dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
- inplace: bool) -> OrderedDict:
+ def gather_from_sharded_optimizer_state(
+ state: OrderedDict,
+ param: torch.Tensor,
+ original_shape: torch.Size,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ use_zero: bool,
+ inplace: bool,
+ ) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -641,14 +681,13 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor,
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
- if isinstance(v, torch.Tensor) and k != 'step':
-
+ if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
- v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
@@ -661,9 +700,14 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor,
return state_
- def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
- original_shape: torch.Size, device: torch.device,
- inplace: bool) -> OrderedDict:
+ def shard_from_complete_optimizer_state(
+ self,
+ state: OrderedDict,
+ current_shape: torch.Size,
+ original_shape: torch.Size,
+ device: torch.device,
+ inplace: bool,
+ ) -> OrderedDict:
"""
With the complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by the current device.
@@ -681,8 +725,7 @@ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape:
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
- if isinstance(v, torch.Tensor) and k != 'step':
-
+ if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
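
As background for the hunks above: the gather path stacks the per-rank ZeRO shards, flattens them, trims the padding, and reshapes back to the parameter shape. A minimal sketch of that idiom with stand-in tensors in place of dist.all_gather (shapes here are illustrative only):

    import torch

    param = torch.empty(5, 3)                            # 15 elements in the full parameter
    dp_size = 4
    shard = torch.zeros(4)                               # each rank holds 4 padded elements (4 * 4 >= 15)
    gathered = [shard.clone() for _ in range(dp_size)]   # stand-in for dist.all_gather
    full = torch.stack(gathered).view(-1)[: param.numel()].reshape_as(param)
    assert full.shape == param.shape

The slice [: param.numel()] is what discards the tail padding that ZeRO adds to make the flat buffer evenly divisible.
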
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
index 388cf3fbe9bb..da12c146f2c3 100644
--- a/colossalai/checkpoint_io/index_file.py
+++ b/colossalai/checkpoint_io/index_file.py
@@ -6,7 +6,7 @@
from .utils import is_dtensor_checkpoint
-__all__ = ['CheckpointIndexFile']
+__all__ = ["CheckpointIndexFile"]
class CheckpointIndexFile:
@@ -50,7 +50,7 @@ def load(self, json_path: str):
json_path (str): path to the json file.
"""
# load the json file
- with open(json_path, 'r') as f:
+ with open(json_path, "r") as f:
index = json.load(f)
# assign attributes if exists
@@ -75,7 +75,7 @@ def export(self, json_path: str):
index["weight_map"] = self.weight_map
# export the index file
- with open(json_path, 'w') as f:
+ with open(json_path, "w") as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):
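
For context, CheckpointIndexFile boils down to a JSON round trip over a metadata/weight_map pair. A minimal sketch of the load/export cycle touched above (file name and entries are made up):

    import json
    from pathlib import Path

    index = {"metadata": {"total_size": 123}, "weight_map": {"linear.weight": "shard-00001.bin"}}
    path = Path("model.index.json")
    path.write_text(json.dumps(index, indent=4))   # mirrors export()
    loaded = json.loads(path.read_text())          # mirrors load()
    assert loaded["weight_map"]["linear.weight"] == "shard-00001.bin"
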
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 664ac63e45ac..c22b76dd46f7 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,5 +1,4 @@
# coding=utf-8
-import copy
import os
import re
from collections import abc as container_abcs
@@ -12,7 +11,7 @@
import torch.nn as nn
from torch.optim import Optimizer
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import OptimizerWrapper
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@@ -55,7 +54,6 @@ def is_safetensors_available() -> bool:
bool: whether safetensors is available.
"""
try:
-        import safetensors
+        import safetensors  # noqa: F401 (probe for availability)
return True
except ImportError:
return False
@@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a dtensor checkpoint.
"""
- if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
+ if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
return True
else:
return False
@@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a safetensor checkpoint.
"""
- if checkpoint_file_path.endswith('.safetensors'):
+ if checkpoint_file_path.endswith(".safetensors"):
return True
else:
return False
@@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
partition_dim = dim
break
if partition_dim is not None:
- assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
- f"The parameter isn't evenly distributed among tensor parallel group: \
+ assert (
+ original_shape[partition_dim] == tp_size * current_shape[partition_dim]
+ ), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim
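
The assert reformatted above guards the invariant that the sharded dimension divides evenly across tensor parallel ranks. A self-contained sketch of the dimension search it belongs to (hypothetical helper name, not part of this patch):

    import torch

    def find_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int):
        # the first dimension whose size changed is the one that was sharded
        for dim, (cur, orig) in enumerate(zip(current_shape, original_shape)):
            if cur != orig:
                assert orig == tp_size * cur, "parameter is not evenly distributed"
                return dim
        return None

    assert find_partition_dim(torch.Size([2, 8]), torch.Size([2, 32]), tp_size=4) == 1
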
@@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
# Helper classes and functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
- '''
+ """
Unwrap a wrapped optimizer.
This method should be used before saving/loading the optimizer to/from sharded checkpoints.
- '''
+ """
unwrapped_optim = optimizer.optim
return unwrapped_optim
class StateDictSharder:
-
def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard
self.current_block = OrderedDict()
self.current_block_size = 0
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
-
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
@@ -159,13 +156,11 @@ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[Ordere
return ret_block, ret_block_size
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
-
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
-
# When state_tensor is not of Tensor class,
# e.g., an SGD optimizer with momentum set to 0 can have None as state,
# the calculation of tensor size should be skipped to avoid errors.
@@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
return param_
-def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
- checkpoint: str,
- index_file: "CheckpointIndexFile",
- base_filename: str,
- is_master: bool,
- use_safetensors: bool = False,
- use_pp_format: bool = False) -> int:
- '''
+def save_state_dict_shards(
+ sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+ checkpoint: str,
+ index_file: "CheckpointIndexFile",
+ base_filename: str,
+ is_master: bool,
+ use_safetensors: bool = False,
+ use_pp_format: bool = False,
+) -> int:
+ """
Save the sharded state dict only on the master rank; this method can be used for both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, where each shard contains a state dict and its size.
@@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
Returns:
int: the total size of shards
- '''
+ """
total_size = 0
shard_filenames = []
@@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
"""
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
- states = state_dict['state']
+ states = state_dict["state"]
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
@@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
- assert checkpoint_file_path.endswith('.safetensors'), \
- "safetensors only supports .safetensors suffix for checkpoint file."
+ assert checkpoint_file_path.endswith(
+ ".safetensors"
+ ), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
+
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
@@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
torch.save(param_groups, group_file_path)
-def clean_folder(checkpoint_path: str,
- weights_name: str,
- shard_filenames: List[str],
- is_master: bool = True,
- use_pp_format: bool = False):
+def clean_folder(
+ checkpoint_path: str,
+ weights_name: str,
+ shard_filenames: List[str],
+ is_master: bool = True,
+ use_pp_format: bool = False,
+):
"""
Clean up the unneeded files in the checkpoint directory after shards of the state_dict have been saved.
@@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str,
else:
# When this checkpoint is created by a pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
- if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
- and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+ if (
+ filename.startswith(weights_no_suffix)
+ and os.path.isfile(full_filename)
+ and filename not in shard_filenames
+ and reg.fullmatch(filename_no_suffix) is not None
+ ):
os.remove(full_filename)
@@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
- output_root_path = root_path.joinpath('dtensor')
+ output_root_path = root_path.joinpath("dtensor")
# create directory
output_root_path.mkdir(exist_ok=True)
@@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
# update the weight map
# * means all shards
- ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
+ ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
@@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
str: checkpoint file suffix.
"""
if use_safetensors:
- return '.safetensors'
+ return ".safetensors"
else:
- return '.bin'
+ return ".bin"
-def generate_checkpoint_shard_file_name(index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None) -> str:
+def generate_checkpoint_shard_file_name(
+ index: int, total_number: int, use_safetensors: bool, prefix: str = None
+) -> str:
"""
Generate checkpoint shard file name.
@@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
- return f'{param_name}.{index}.{suffix}'
+ return f"{param_name}.{index}.{suffix}"
# ========================================
@@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
+
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
- f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
+ f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+ )
return safe_load_file(checkpoint_file)
else:
- return torch.load(checkpoint_file, map_location=torch.device('cpu'))
+ return torch.load(checkpoint_file, map_location=torch.device("cpu"))
-def load_state_dict_into_model(model: nn.Module,
- state_dict: torch.Tensor,
- missing_keys: List,
- strict: bool = False,
- load_sub_module: bool = True):
+def load_state_dict_into_model(
+ model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True
+):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
@@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module,
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, '_metadata', None)
+ metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata
@@ -560,10 +564,12 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True)
if strict:
if len(unexpected_keys) > 0:
- error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
- '"{}"'.format(k) for k in unexpected_keys))
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
- model.__class__.__name__, "\n\t".join(error_msgs)))
+ error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
+ ", ".join('"{}"'.format(k) for k in unexpected_keys)
+ )
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
+ )
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
@@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
- saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
+ saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
if not isinstance(saved_groups, List):
- raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
+ raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
# The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch.
@@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups")
- param_lens = (len(g['params']) for g in param_groups)
- saved_lens = (len(g['params']) for g in saved_groups)
+ param_lens = (len(g["params"]) for g in param_groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
- raise ValueError("loaded state dict contains a parameter group "
- "that doesn't match the size of optimizer's group")
+ raise ValueError(
+ "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
+ )
# Creating mapping from id to parameters.
id_map = {
- old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
- )), chain.from_iterable((g['params'] for g in param_groups)))
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in param_groups)),
+ )
}
# Update parameter groups, setting their 'params' value.
def update_group(group, new_group):
- new_group['params'] = group['params']
+ new_group["params"] = group["params"]
return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
- optimizer.__dict__.update({'param_groups': updated_groups})
+ optimizer.__dict__.update({"param_groups": updated_groups})
return id_map
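
The dict comprehension above pairs old integer ids from the checkpoint with live parameters by zipping the two flattened group lists in order. A worked toy example with strings standing in for tensors:

    from itertools import chain

    saved_groups = [{"params": [0, 1]}, {"params": [2]}]
    param_groups = [{"params": ["weight", "bias"]}, {"params": ["gamma"]}]
    id_map = {
        old_id: p
        for old_id, p in zip(
            chain.from_iterable(g["params"] for g in saved_groups),
            chain.from_iterable(g["params"] for g in param_groups),
        )
    }
    assert id_map == {0: "weight", 1: "bias", 2: "gamma"}
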
@@ -628,7 +638,7 @@ def cast(param, value, key=None):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not cast: https://github.com/pytorch/pytorch/issues/74424
- if (key != "step"):
+ if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
@@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
"""
# Do the cleaning up as in the source code of PyTorch.
- optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
- optimizer.defaults.setdefault('differentiable', False)
+ optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
+ optimizer.defaults.setdefault("differentiable", False)
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
@@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return False, None
elif checkpoint_path.is_dir():
# check if there is only one file ending with .index.json in this directory
- index_files = list(checkpoint_path.glob('*.index.*json'))
+ index_files = list(checkpoint_path.glob("*.index.*json"))
# if we found a .index.json file, make sure there is only one
if len(index_files) > 0:
- assert len(
- index_files
- ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
+ assert (
+ len(index_files) == 1
+ ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
if len(index_files) == 1:
return True, index_files[0]
else:
return False, None
else:
- raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
+ raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def load_state_dict(checkpoint_file_path: Path):
@@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path):
dict: state dict.
"""
- assert not is_dtensor_checkpoint(checkpoint_file_path), \
- f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
+ assert not is_dtensor_checkpoint(
+ checkpoint_file_path
+ ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
if is_safetensor_checkpoint(checkpoint_file_path):
- assert is_safetensors_available(), \
- f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
+ assert (
+ is_safetensors_available()
+ ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
# load with safetensors
from safetensors import safe_open
+
state_dict = {}
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
for k in f.keys():
@@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
- return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
+ return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
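
Stepping back, StateDictSharder in this file implements size-capped sharding: tensors accumulate into a block until the cap would be exceeded, then a new block is started. A minimal sketch of that policy under simplified assumptions (a plain generator instead of the class, no DTensor or optimizer-state handling):

    from collections import OrderedDict

    import torch

    def shard_by_size(named_tensors, max_shard_size):
        block, block_size = OrderedDict(), 0
        for name, tensor in named_tensors:
            size = tensor.numel() * tensor.element_size()
            if block and block_size + size > max_shard_size:
                yield block, block_size            # close the current shard
                block, block_size = OrderedDict(), 0
            block[name] = tensor
            block_size += size
        yield block, block_size                    # flush the last shard

    shards = list(shard_by_size([("a", torch.zeros(8)), ("b", torch.zeros(8))], max_shard_size=40))
    assert len(shards) == 2                        # 32 bytes each, so they cannot share a 40-byte shard
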
diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py
index 658e35e4c72e..c7cb19c19308 100644
--- a/colossalai/cli/__init__.py
+++ b/colossalai/cli/__init__.py
@@ -1,3 +1,3 @@
from .cli import cli
-__all__ = ['cli']
+__all__ = ["cli"]
diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py
index a86b32bb6a18..7c26ab6ade6c 100644
--- a/colossalai/cli/check/__init__.py
+++ b/colossalai/cli/check/__init__.py
@@ -1,11 +1,12 @@
import click
+
from .check_installation import check_installation
-__all__ = ['check']
+__all__ = ["check"]
@click.command(help="Check if Colossal-AI is correct based on the given option")
-@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly")
+@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly")
def check(installation):
if installation:
check_installation()
diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py
index 4a481f3bd122..772c513ffa06 100644
--- a/colossalai/cli/check/check_installation.py
+++ b/colossalai/cli/check/check_installation.py
@@ -9,7 +9,7 @@
def to_click_output(val):
# map installation check outputs to understandable symbols for readability
- VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'}
+ VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"}
if val in VAL_TO_SYMBOL:
return VAL_TO_SYMBOL[val]
@@ -55,8 +55,8 @@ def check_installation():
else:
torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required])
- click.echo(f'#### Installation Report ####')
- click.echo(f'\n------------ Environment ------------')
+ click.echo(f"#### Installation Report ####")
+ click.echo(f"\n------------ Environment ------------")
click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}")
click.echo(f"PyTorch version: {to_click_output(torch_version)}")
click.echo(f"System CUDA version: {to_click_output(cuda_version)}")
@@ -69,7 +69,7 @@ def check_installation():
f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version."
)
- click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------')
+ click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------")
click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}")
click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}")
click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}")
@@ -81,7 +81,7 @@ def check_installation():
click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime")
click.echo(f"\n------------ Compatibility ------------")
- click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}')
+ click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}")
click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}")
click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}")
click.echo(f"")
@@ -106,12 +106,12 @@ def _is_compatible(versions):
return False
# split version into [major, minor, patch]
- versions = [version.split('.') for version in versions]
+ versions = [version.split(".") for version in versions]
for version in versions:
if len(version) == 2:
# x means unknown
- version.append('x')
+ version.append("x")
for idx, version_values in enumerate(zip(*versions)):
equal = len(set(version_values)) == 1
@@ -137,11 +137,11 @@ def _parse_colossalai_version():
# 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions)
# 2. X.X.X (when colossalai is not installed with CUDA extensions)
# where X represents an integer.
- colossalai_version = colossalai.__version__.split('+')[0]
+ colossalai_version = colossalai.__version__.split("+")[0]
try:
- torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0]
- cuda_version_for_aot_build = colossalai.__version__.split('cu')[1]
+ torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0]
+ cuda_version_for_aot_build = colossalai.__version__.split("cu")[1]
except:
torch_version_for_aot_build = None
cuda_version_for_aot_build = None
@@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed():
JIT (just-in-time) compilation will build CUDA extensions under `~/.cache/colossalai/torch_extensions` at runtime.
"""
try:
-        import colossalai._C.fused_optim
+        import colossalai._C.fused_optim  # noqa: F401 (probe for the AOT-built extension)
found_aot_cuda_ext = True
except ImportError:
found_aot_cuda_ext = False
@@ -175,14 +174,14 @@ def _check_torch_version():
# torch version can be of two formats
# - 1.13.1+cu113
# - 1.13.1.devxxx
- torch_version = torch.__version__.split('+')[0]
- torch_version = '.'.join(torch_version.split('.')[:3])
+ torch_version = torch.__version__.split("+")[0]
+ torch_version = ".".join(torch_version.split(".")[:3])
# get cuda version in pytorch build
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
- torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'
+ torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}"
except:
torch_cuda_version = None
@@ -208,7 +207,7 @@ def _check_cuda_version():
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
- cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
+ cuda_version = f"{bare_metal_major}.{bare_metal_minor}"
except:
cuda_version = None
return cuda_version
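
As background for the _is_compatible changes above: versions are compared component-wise, and a missing patch level is padded with the wildcard "x", which matches anything. A simplified sketch of that comparison (not the exact implementation):

    def versions_match(a: str, b: str) -> bool:
        va, vb = a.split("."), b.split(".")
        for v in (va, vb):
            if len(v) == 2:
                v.append("x")   # "x" means unknown and matches any value
        return all(x == y or "x" in (x, y) for x, y in zip(va, vb))

    assert versions_match("1.13", "1.13.1")
    assert not versions_match("1.12.1", "1.13.1")
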
diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py
index 0dea7c504957..0d94fe59f8ae 100644
--- a/colossalai/cli/cli.py
+++ b/colossalai/cli/cli.py
@@ -4,8 +4,7 @@
from .launcher import run
-class Arguments():
-
+class Arguments:
def __init__(self, arg_dict):
for k, v in arg_dict.items():
self.__dict__[k] = v
@@ -19,5 +18,5 @@ def cli():
cli.add_command(run)
cli.add_command(check)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli()
diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py
index 808e4e84574f..0f9ead6495db 100644
--- a/colossalai/cli/launcher/__init__.py
+++ b/colossalai/cli/launcher/__init__.py
@@ -5,56 +5,81 @@
from .run import launch_multi_processes
-@click.command(help="Launch distributed training on a single node or multiple nodes",
- context_settings=dict(ignore_unknown_options=True))
-@click.option("-H",
- "-host",
- "--host",
- type=str,
- default=None,
-              help="the list of hostnames to launch in the format <host1>,<host2>")
+@click.command(
+ help="Launch distributed training on a single node or multiple nodes",
+ context_settings=dict(ignore_unknown_options=True),
+)
+@click.option(
+ "-H",
+ "-host",
+ "--host",
+ type=str,
+ default=None,
+ help="the list of hostnames to launch in the format ,",
+)
@click.option(
"--hostfile",
type=str,
default=None,
- help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
-@click.option("--include",
- type=str,
- default=None,
-              help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
- " only effective when used with --hostfile.")
+ help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname",
+)
+@click.option(
+ "--include",
+ type=str,
+ default=None,
+ help="Specify computing devices to use during execution. String format is ,,"
+ " only effective when used with --hostfile.",
+)
@click.option(
"--exclude",
type=str,
default=None,
- help=
- "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
- " only effective when used with --hostfile.")
-@click.option("--num_nodes",
- type=int,
- default=-1,
- help="Total number of worker nodes to use, only effective when used with --hostfile.")
+ help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
+ " only effective when used with --hostfile.",
+)
+@click.option(
+ "--num_nodes",
+ type=int,
+ default=-1,
+ help="Total number of worker nodes to use, only effective when used with --hostfile.",
+)
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
-@click.option("--master_port",
- type=int,
- default=29500,
- help="(optional) Port used by PyTorch distributed for communication during distributed training.")
-@click.option("--master_addr",
- type=str,
- default="127.0.0.1",
- help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
+@click.option(
+ "--master_port",
+ type=int,
+ default=29500,
+ help="(optional) Port used by PyTorch distributed for communication during distributed training.",
+)
+@click.option(
+ "--master_addr",
+ type=str,
+ default="127.0.0.1",
+ help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.",
+)
@click.option(
"--extra_launch_args",
type=str,
default=None,
- help=
- "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
- "This will be converted to --arg1=1 --arg2=2 during execution")
+ help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
+ "This will be converted to --arg1=1 --arg2=2 during execution",
+)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
-@click.argument('user_args', nargs=-1)
-def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
- master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
+@click.argument("user_args", nargs=-1)
+def run(
+ host: str,
+ hostfile: str,
+ num_nodes: int,
+ nproc_per_node: int,
+ include: str,
+ exclude: str,
+ master_addr: str,
+ master_port: int,
+ extra_launch_args: str,
+ ssh_port: int,
+ user_script: str,
+ user_args: str,
+) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include:
# run with hostfile excluding the hosts selected
colossalai run --hostfile <hostfile> --master_addr host1 --exclude host2 --nproc_per_node 4 train.py
"""
- if not user_script.endswith('.py'):
- click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
+ if not user_script.endswith(".py"):
+ click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
args_dict = locals()
diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py
index 2a6a111e4d72..684f64f59d28 100644
--- a/colossalai/cli/launcher/hostinfo.py
+++ b/colossalai/cli/launcher/hostinfo.py
@@ -1,5 +1,4 @@
import socket
-from typing import List
class HostInfo:
@@ -34,7 +33,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None:
"""
if port is None:
- port = 22 # no port specified, lets just use the ssh port
+        port = 22  # no port specified, let's just use the ssh port
# socket.getfqdn("127.0.0.1") does not return localhost
# on some users' machines
@@ -50,7 +49,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None:
return localaddrs == targetaddrs
def __str__(self):
- return f'hostname: {self.hostname}, port: {self.port}'
+ return f"hostname: {self.hostname}, port: {self.port}"
def __repr__(self):
return self.__str__()
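
The localhost check in this file resolves both the target hostname and the local hostname and compares the resulting address lists, since socket.getfqdn("127.0.0.1") is unreliable on some machines. A minimal standalone sketch of the same idea:

    import socket

    def is_localhost(hostname: str, port: int = 22) -> bool:
        localaddrs = socket.getaddrinfo(socket.gethostname(), port)
        targetaddrs = socket.getaddrinfo(hostname, port)
        return localaddrs == targetaddrs
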
diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py
index 85b241e96292..99c4db406844 100644
--- a/colossalai/cli/launcher/multinode_runner.py
+++ b/colossalai/cli/launcher/multinode_runner.py
@@ -7,8 +7,13 @@
from .hostinfo import HostInfo, HostInfoList
-def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
- send_conn: mp_connection.Connection, env: dict) -> None:
+def run_on_host(
+ hostinfo: HostInfo,
+ workdir: str,
+ recv_conn: mp_connection.Connection,
+ send_conn: mp_connection.Connection,
+ env: dict,
+) -> None:
"""
Use fabric connection to execute command on local or remote hosts.
@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
finish = False
- env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
+ env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()])
# keep listening until exit
while not finish:
# receive cmd
cmds = recv_conn.recv()
- if cmds == 'exit':
+ if cmds == "exit":
# exit from the loop
finish = True
break
@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
else:
# execute on the remote machine
fab_conn.run(cmds, hide=False)
- send_conn.send('success')
+ send_conn.send("success")
except Exception as e:
click.echo(
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
)
- send_conn.send('failure')
+ send_conn.send("failure")
# shutdown
send_conn.send("finish")
@@ -96,8 +101,7 @@ def send(self, hostinfo: HostInfo, cmd: str) -> None:
cmd (str): the command to execute
"""
- assert hostinfo.hostname in self.master_send_conns, \
- f'{hostinfo} is not found in the current connections'
+ assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
conn = self.master_send_conns[hostinfo.hostname]
conn.send(cmd)
@@ -107,7 +111,7 @@ def stop_all(self) -> None:
"""
for hostname, conn in self.master_send_conns.items():
- conn.send('exit')
+ conn.send("exit")
def recv_from_all(self) -> dict:
"""
diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py
index d2d02811ac9d..7ca8ee90386c 100644
--- a/colossalai/cli/launcher/run.py
+++ b/colossalai/cli/launcher/run.py
@@ -12,7 +12,7 @@
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
-NODE_SEP = ','
+NODE_SEP = ","
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit()
- with open(hostfile_path, 'r') as fd:
+ with open(hostfile_path, "r") as fd:
device_pool = HostInfoList()
for line in fd.readlines():
line = line.strip()
- if line == '':
+ if line == "":
# skip empty lines
continue
@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
- '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
+ """Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples:
include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
- '''
+ """
# Ensure include/exclude are mutually exclusive
if include_str and exclude_str:
@@ -136,16 +136,16 @@ def _arg_dict_to_list(arg_dict):
for k, v in arg_dict.items():
if v:
- ret.append(f'--{k}={v}')
+ ret.append(f"--{k}={v}")
else:
- ret.append(f'--{k}')
+ ret.append(f"--{k}")
return ret
if extra_launch_args:
extra_launch_args_dict = dict()
- for arg in extra_launch_args.split(','):
- if '=' in arg:
- k, v = arg.split('=')
+ for arg in extra_launch_args.split(","):
+ if "=" in arg:
+ k, v = arg.split("=")
extra_launch_args_dict[k] = v
else:
extra_launch_args_dict[arg] = None
@@ -158,9 +158,14 @@ def _arg_dict_to_list(arg_dict):
if torch_version.minor < 9:
cmd = [
- sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
- f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
- f"--node_rank={node_rank}"
+ sys.executable,
+ "-m",
+ "torch.distributed.launch",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--master_addr={master_addr}",
+ f"--master_port={master_port}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
else:
# extra launch args for torch distributed launcher with torch >= 1.9
@@ -174,17 +179,24 @@ def _arg_dict_to_list(arg_dict):
if torch_version.minor < 10:
cmd = [
- sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
- f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
else:
cmd = [
- "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+ "torchrun",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
- cmd = ' '.join(cmd)
+ cmd = " ".join(cmd)
return cmd
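
For clarity, _arg_dict_to_list above turns the parsed --extra_launch_args mapping into torchrun-style flags: entries with a value become --k=v and bare entries become --k. A worked example of the conversion (standalone helper, mirroring the inner function):

    def arg_dict_to_list(arg_dict):
        return [f"--{k}={v}" if v else f"--{k}" for k, v in arg_dict.items()]

    assert arg_dict_to_list({"arg1": "1", "standalone": None}) == ["--arg1=1", "--standalone"]
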
@@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None:
# run on the local node if no hosts or hostfile are given
# add local node to host info list
active_device_pool = HostInfoList()
- localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
+ localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
active_device_pool.append(localhost_info)
# launch distributed processes
runner = MultiNodeRunner()
- curr_path = os.path.abspath('.')
+ curr_path = os.path.abspath(".")
# collect current path env
env = dict()
for k, v in os.environ.items():
# do not support multi-line env var
- if v and '\n' not in v:
+ if v and "\n" not in v:
env[k] = v
# establish remote connection
@@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None:
# execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool):
- cmd = get_launch_command(master_addr=args.master_addr,
- master_port=args.master_port,
- nproc_per_node=args.nproc_per_node,
- user_script=args.user_script,
- user_args=args.user_args,
- node_rank=node_id,
- num_nodes=len(active_device_pool),
- extra_launch_args=args.extra_launch_args)
+ cmd = get_launch_command(
+ master_addr=args.master_addr,
+ master_port=args.master_port,
+ nproc_per_node=args.nproc_per_node,
+ user_script=args.user_script,
+ user_args=args.user_args,
+ node_rank=node_id,
+ num_nodes=len(active_device_pool),
+ extra_launch_args=args.extra_launch_args,
+ )
runner.send(hostinfo=hostinfo, cmd=cmd)
# start training
diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py
index 44f571ca2501..b8176feb647b 100644
--- a/colossalai/cluster/__init__.py
+++ b/colossalai/cluster/__init__.py
@@ -3,4 +3,4 @@
from .process_group_manager import ProcessGroupManager
from .process_group_mesh import ProcessGroupMesh
-__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh']
+__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"]
diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py
index 8754baa19792..e35aca5f4d7e 100644
--- a/colossalai/cluster/device_mesh_manager.py
+++ b/colossalai/cluster/device_mesh_manager.py
@@ -10,13 +10,14 @@
@dataclass
class DeviceMeshInfo:
- '''
+ """
This class is used to store the information used to initialize the device mesh.
Args:
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on an 8-device cluster, then the physical ids should be [4, 5, 6, 7].
mesh_shape (Union[torch.Size, List[int], Tuple[int]]): The shape of the mesh. For example, if we have 4 GPUs and want to use a 2D mesh, then the mesh shape could be [2, 2].
- '''
+ """
+
physical_ids: List[int]
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
@@ -24,16 +25,18 @@ def __post_init__(self):
if self.mesh_shape is not None:
world_size = len(self.physical_ids)
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
- assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
+ assert (
+ world_size == mesh_shape_numel
+ ), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
- '''
+ """
This method is used to initialize the device mesh.
Args:
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
- '''
+ """
# parse the device mesh info
physical_devices = device_mesh_info.physical_ids
physical_mesh = torch.tensor(physical_devices)
@@ -67,13 +70,13 @@ def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMe
Args:
name (str): name of the device mesh
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
- """
+ """
if name not in self.device_mesh_store:
device_mesh = initialize_device_mesh(device_mesh_info)
self.device_mesh_store[name] = device_mesh
return device_mesh
else:
- raise ValueError(f'Device mesh {name} already exists.')
+ raise ValueError(f"Device mesh {name} already exists.")
def get(self, name: str) -> DeviceMesh:
"""
@@ -88,7 +91,7 @@ def get(self, name: str) -> DeviceMesh:
if name in self.device_mesh_store:
return self.device_mesh_store[name]
else:
- raise ValueError(f'Device mesh {name} does not exist.')
+ raise ValueError(f"Device mesh {name} does not exist.")
def destroy(self, name: str) -> None:
"""
@@ -103,7 +106,7 @@ def destroy(self, name: str) -> None:
dist.destroy_process_group(pg)
del self.device_mesh_store[name]
else:
- raise ValueError(f'Device mesh {name} does not exist.')
+ raise ValueError(f"Device mesh {name} does not exist.")
def destroy_all(self):
"""
diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py
index 3ee364ec3364..5b66e88717ba 100644
--- a/colossalai/cluster/dist_coordinator.py
+++ b/colossalai/cluster/dist_coordinator.py
@@ -36,12 +36,13 @@ class in the whole program.
"""
def __init__(self):
- assert dist.is_initialized(
- ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
+ assert (
+ dist.is_initialized()
+ ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
- self._local_rank = os.environ.get('LOCAL_RANK', -1)
+        self._local_rank = int(os.environ.get("LOCAL_RANK", -1))
@property
def rank(self) -> int:
@@ -59,7 +60,9 @@ def _assert_local_rank_set(self):
"""
Assert that the local rank is set. This is often passed by launchers such as torchrun.
"""
- assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
+ assert (
+ self.local_rank >= 0
+ ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."
def is_master(self, process_group: ProcessGroup = None) -> bool:
"""
@@ -183,7 +186,6 @@ def on_master_only(self, process_group: ProcessGroup = None):
# define an inner function
def decorator(func):
-
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master:
diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py
index e52661846f3e..68106b503126 100644
--- a/colossalai/cluster/process_group_manager.py
+++ b/colossalai/cluster/process_group_manager.py
@@ -19,7 +19,7 @@ class ProcessGroupManager:
def __init__(self):
self.pg_store = dict()
- def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
+ def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup:
"""
Get a process group by name. If the process group does not exist, it will be created.
@@ -36,7 +36,7 @@ def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl
self.pg_store[name] = pg
return pg
else:
- raise ValueError(f'Process group {name} already exists.')
+ raise ValueError(f"Process group {name} already exists.")
def get(self, name: str) -> ProcessGroup:
"""
@@ -51,7 +51,7 @@ def get(self, name: str) -> ProcessGroup:
if name in self.pg_store:
return self.pg_store[name]
else:
- raise ValueError(f'Process group {name} does not exist.')
+ raise ValueError(f"Process group {name} does not exist.")
def destroy(self, name: str) -> None:
"""
@@ -64,7 +64,7 @@ def destroy(self, name: str) -> None:
dist.destroy_process_group(self.pg_store[name])
del self.pg_store[name]
else:
- raise ValueError(f'Process group {name} does not exist.')
+ raise ValueError(f"Process group {name} does not exist.")
def destroy_all(self) -> None:
"""
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 623160003767..3885bc962561 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -94,7 +94,7 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
return np.unravel_index(rank, shape)
@staticmethod
- def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
+ def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
"""Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
with wrap, an index out of range is wrapped around.
@@ -141,8 +141,9 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
return list(self._group_to_ranks[group])
@staticmethod
- def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
- indices_at_axis: List[int]) -> List[Tuple[int, ...]]:
+ def get_coords_along_axis(
+ base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
+ ) -> List[Tuple[int, ...]]:
"""Get coordinates along the given axis.
Args:
@@ -155,13 +156,12 @@ def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
"""
coords_in_group = []
for idx in indices_at_axis:
- coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:])
+ coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
return coords_in_group
- def create_group_along_axis(self,
- axis: int,
- indices_at_axis: Optional[List[int]] = None,
- backend: Optional[str] = None) -> ProcessGroup:
+ def create_group_along_axis(
+ self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ ) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
Args:
@@ -186,10 +186,9 @@ def create_group_along_axis(self,
target_group = group
return target_group
- def get_group_along_axis(self,
- axis: int,
- indices_at_axis: Optional[List[int]] = None,
- backend: Optional[str] = None) -> ProcessGroup:
+ def get_group_along_axis(
+ self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ ) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
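
The ravel/unravel helpers in this file are thin wrappers over numpy, and the "wrap" mode mentioned in the docstring wraps out-of-range coordinates instead of raising. A small demonstration:

    import numpy as np

    shape = (2, 3)
    assert np.unravel_index(4, shape) == (1, 1)
    assert np.ravel_multi_index((1, 1), shape) == 4
    # with mode="wrap", coordinate 4 on an axis of size 3 wraps around to 1
    assert np.ravel_multi_index((1, 4), shape, mode="wrap") == 4
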
diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py
index eb6d5d05a008..ab57301bb910 100644
--- a/colossalai/context/__init__.py
+++ b/colossalai/context/__init__.py
@@ -3,6 +3,6 @@
# from .moe_context import MOE_CONTEXT
__all__ = [
- 'Config',
- 'ConfigException',
+ "Config",
+ "ConfigException",
]
diff --git a/colossalai/context/config.py b/colossalai/context/config.py
index 8903707708df..05a2e4bf044a 100644
--- a/colossalai/context/config.py
+++ b/colossalai/context/config.py
@@ -5,6 +5,7 @@
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path
+
from colossalai.logging import get_dist_logger
@@ -41,7 +42,7 @@ def _add_item(self, key, value):
self.__setattr__(key, value)
def update(self, config):
- assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
+ assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
for k, v in config.items():
self._add_item(k, v)
return self
@@ -66,11 +67,11 @@ def from_file(filename: str):
elif isinstance(filename, Path):
filepath = filename.absolute()
- assert filepath.exists(), f'{filename} is not found, please check your configuration path'
+ assert filepath.exists(), f"{filename} is not found, please check your configuration path"
# check extension
extension = filepath.suffix
- assert extension == '.py', 'only .py files are supported'
+ assert extension == ".py", "only .py files are supported"
# import the config as module
remove_path = False
@@ -86,13 +87,13 @@ def from_file(filename: str):
config = Config()
for k, v in module.__dict__.items():
- if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
+ if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
continue
else:
config._add_item(k, v)
logger = get_dist_logger()
- logger.debug('variables which starts with __, is a module or class declaration are omitted in config file')
+ logger.debug("variables which starts with __, is a module or class declaration are omitted in config file")
# remove module
del sys.modules[module_name]
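
Config.from_file above imports a .py config file as a module and keeps only plain values. A minimal sketch of the filtering rule, mirroring the condition in the hunk (hypothetical helper name):

    import inspect

    def public_config_items(module_dict):
        # skip dunders, modules, and class declarations
        return {
            k: v
            for k, v in module_dict.items()
            if not (k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v))
        }

    assert public_config_items({"lr": 0.1, "__doc__": None, "inspect": inspect}) == {"lr": 0.1}
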
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
index b6e3b52017b2..066dfc7222e1 100644
--- a/colossalai/context/moe_context.py
+++ b/colossalai/context/moe_context.py
@@ -9,14 +9,13 @@
def _check_sanity():
from colossalai.legacy.core import global_context as gpc
+
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
- raise NotImplementedError("Moe is not compatible with tensor or "
- "pipeline parallel at present.")
+ raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
class MoeParallelInfo:
- """Moe parallelism information, storing parallel sizes and groups.
- """
+ """Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_size: int, dp_size: int):
_check_sanity()
@@ -61,9 +60,11 @@ def setup(self, seed: int, use_kernel_optim: bool = True):
self.world_size = dist.get_world_size()
from colossalai.legacy.core import global_context as gpc
- self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
- assert self.world_size % self.max_ep_size == 0, \
- "Maximum expert parallel size must be a factor of the number of GPUs"
+
+ self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
+ assert (
+ self.world_size % self.max_ep_size == 0
+ ), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise error in some cases
@@ -71,6 +72,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True):
self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed
+
moe_set_seed(seed)
self.has_setup = True
@@ -88,11 +90,13 @@ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
number of local experts, the MoeParallelInfo of the current ep_size
"""
- gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
- lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
+ gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
+ lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
- assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
- " is not a multiple of ep size or vice versa."
+ assert gt_flag or lt_flag, (
+ "Automatic experts placement dose not not support expert number"
+ " is not a multiple of ep size or vice versa."
+ )
# If the number of experts is greater than the maximum expert parallel size, a.k.a. ep_size,
# there are multiple experts on each GPU and each GPU holds different experts.
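
The gt_flag/lt_flag pair above encodes a divisibility rule between the expert count and the expert parallel size: placement is valid only if one divides the other. A worked example:

    max_ep_size = 4
    for num_experts in (8, 4, 2):
        gt_flag = num_experts % max_ep_size == 0   # at least one expert per ep rank
        lt_flag = max_ep_size % num_experts == 0   # one expert shared by several ranks
        assert gt_flag or lt_flag
    # num_experts = 3 satisfies neither flag, so placement would be rejected
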
diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py
index 8ca335119d52..3088b0dffaac 100644
--- a/colossalai/context/singleton_meta.py
+++ b/colossalai/context/singleton_meta.py
@@ -16,6 +16,7 @@ def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
- assert len(args) == 0 and len(
- kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
+ assert (
+ len(args) == 0 and len(kwargs) == 0
+ ), f"{cls.__name__} is a singleton class and a instance has been created."
return cls._instances[cls]
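
For reference, the metaclass above implements the usual singleton pattern, and re-instantiation with fresh arguments trips the assert. A self-contained usage sketch (SingletonMetaSketch and Context are illustrative names):

    class SingletonMetaSketch(type):
        _instances = {}

        def __call__(cls, *args, **kwargs):
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            else:
                assert not args and not kwargs, f"{cls.__name__} is a singleton and already instantiated"
            return cls._instances[cls]

    class Context(metaclass=SingletonMetaSketch):
        def __init__(self, value=0):
            self.value = value

    assert Context(1) is Context()
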
diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py
index 689189998c3f..34a7d2526fda 100644
--- a/colossalai/device/__init__.py
+++ b/colossalai/device/__init__.py
@@ -1,4 +1,4 @@
from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp
-__all__ = ['AlphaBetaProfiler', 'alpa_dp']
+__all__ = ["AlphaBetaProfiler", "alpa_dp"]
diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py
index f4e6cfffbcdf..88520b2a14d0 100644
--- a/colossalai/device/alpha_beta_profiler.py
+++ b/colossalai/device/alpha_beta_profiler.py
@@ -13,7 +13,7 @@
class AlphaBetaProfiler:
- '''
+ """
Profile alpha and beta value for a given device list.
Usage:
@@ -27,17 +27,19 @@ class AlphaBetaProfiler:
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
- '''
-
- def __init__(self,
- physical_devices: List[int],
- alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
- ctype: str = 'a',
- warmup: int = 5,
- repeat: int = 25,
- latency_iters: int = 5,
- homogeneous_tolerance: float = 0.1):
- '''
+ """
+
+ def __init__(
+ self,
+ physical_devices: List[int],
+ alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
+ ctype: str = "a",
+ warmup: int = 5,
+ repeat: int = 25,
+ latency_iters: int = 5,
+ homogeneous_tolerance: float = 0.1,
+ ):
+ """
Args:
physical_devices: A list of device ids; each element is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
@@ -45,7 +47,7 @@ def __init__(self,
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency.
- '''
+ """
self.physical_devices = physical_devices
self.ctype = ctype
self.world_size = len(physical_devices)
@@ -123,7 +125,7 @@ def _profile(self, process_group, pg_handler, nbytes):
return (None, None)
def profile_latency(self, process_group, pg_handler):
- '''
+ """
This function is used to profile the latency of the given process group across a series of message sizes.
Args:
@@ -132,7 +134,7 @@ def profile_latency(self, process_group, pg_handler):
Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list.
- '''
+ """
latency_list = []
for i in range(self.latency_iters):
nbytes = int(BYTE << i)
@@ -148,26 +150,26 @@ def profile_latency(self, process_group, pg_handler):
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
- '''
+ """
This function is used to profile the bandwidth of the given process group.
Args:
process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group.
- '''
+ """
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth
def profile_ab(self):
- '''
+ """
This method is used to profile the alpha and beta values for a given device list.
Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value.
- '''
+ """
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
- global_pg_handler = dist.new_group(self.physical_devices)
+ dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
@@ -208,7 +210,7 @@ def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
return alpha_beta_dict
def search_best_logical_mesh(self):
- '''
+ """
This method is used to search the best logical mesh for the given device list.
The best logical mesh is searched for in the following steps:
@@ -232,19 +234,19 @@ def search_best_logical_mesh(self):
>>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
- '''
+ """
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
- '''
+ """
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if their beta values
are in the range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta.
- '''
+ """
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None:
@@ -254,7 +256,8 @@ def _detect_homogeneous_device(alpha_beta_dict):
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
- 1 - self.homogeneous_tolerance):
+ 1 - self.homogeneous_tolerance
+ ):
match_beta = beta_value
break
@@ -267,9 +270,9 @@ def _detect_homogeneous_device(alpha_beta_dict):
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
- '''
+ """
This function is used to check whether the homogeneous_group contains all physical devices.
- '''
+ """
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
@@ -277,9 +280,9 @@ def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
- '''
+ """
This function is used to construct the largest ring in the homogeneous_group for each rank.
- '''
+ """
# Construct the ring
ring = []
ranks_in_ring = []
@@ -300,7 +303,9 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
- rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
+ rank_to_append = (
+ process_group[0] if process_group[1] == check_rank else process_group[1]
+ )
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
@@ -314,7 +319,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
- balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
+ balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
@@ -348,7 +353,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
- '''
+ """
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
@@ -360,7 +365,7 @@ def extract_alpha_beta_for_device_mesh(self):
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
- '''
+ """
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
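The profiler above fits the standard alpha-beta communication model: sending n bytes takes
roughly t(n) = alpha + beta * n, where alpha is the fixed latency and beta the per-byte cost.
A minimal sketch of the two-point fit this model implies; fit_alpha_beta is a hypothetical
helper for illustration and is not part of this patch.

    # Hypothetical two-point alpha-beta fit: t(n) ~= alpha + beta * n.
    def fit_alpha_beta(t_small, n_small, t_large, n_large):
        beta = (t_large - t_small) / (n_large - n_small)  # seconds per byte
        alpha = t_small - beta * n_small                  # fixed latency in seconds
        return alpha, beta

    # Example: 1 KiB in 30 us and 1 GiB in 90 ms give alpha ~= 3e-5 s, beta ~= 8.4e-11 s/B.
    alpha, beta = fit_alpha_beta(3e-5, 1 << 10, 9e-2, 1 << 30)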
diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py
index 4ab72dfe60f0..72d432701ada 100644
--- a/colossalai/device/calc_pipeline_strategy.py
+++ b/colossalai/device/calc_pipeline_strategy.py
@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
while i <= num_devices_per_host:
i *= 2
p += 1
- assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
- f"while now num_devices_per_host = {num_devices_per_host}")
+ assert pow(2, p) == num_devices_per_host, (
+ "Only supports the cases where num_devices_per_host is power of two, "
+ f"while now num_devices_per_host = {num_devices_per_host}"
+ )
if mode == "alpa":
for i in range(p + 1):
submesh_choices.append((1, pow(2, i)))
@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
return submesh_choices
-def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
- best_configs):
+def alpa_dp_impl(
+ num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
+):
"""Implementation of Alpa DP for pipeline strategy
- Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
+ Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
- Arguments:
- num_layers: K
- num_devices: N*M
- num_microbatches: B
- submesh_choices: List[(n_i,m_i)]
- compute_cost: t_intra
- """
+ Arguments:
+ num_layers: K
+ num_devices: N*M
+ num_microbatches: B
+ submesh_choices: List[(n_i,m_i)]
+ compute_cost: t_intra
+ """
# For f, layer IDs start from 0
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
- if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
+ if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
- next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
+ next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
- assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
+ assert current_s == 0 and current_layer == num_layers and current_devices == 0
return total_cost, res
-def alpa_dp(num_layers,
- num_devices,
- num_microbatches,
- submesh_choices,
- num_autosharding_configs,
- compute_cost,
- gap=1e-6):
+def alpa_dp(
+ num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
+):
"""Alpa auto stage dynamic programming.
- Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
+ Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Arguments:
submesh_choices: List[(int,int)]
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
"""
- assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
- num_autosharding_configs), "Cost shape wrong."
+ assert np.shape(compute_cost) == (
+ num_layers,
+ num_layers,
+ len(submesh_choices),
+ num_autosharding_configs,
+ ), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
break
if max_stage_cost - last_max_stage_cost < gap:
continue
- cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
- max_stage_cost, best_configs)
+ cost, solution = alpa_dp_impl(
+ num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
+ )
if cost < best_cost:
best_cost = cost
best_solution = solution
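A sketch of driving alpa_dp end to end with a tiny synthetic cost tensor. The mesh sizes and
random costs are illustrative assumptions, and the call assumes alpa_dp returns
(best_cost, best_solution), mirroring what alpa_dp_impl returns above.

    import numpy as np
    from colossalai.device.calc_pipeline_strategy import alpa_dp, get_submesh_choices

    num_layers, num_hosts, num_devices_per_host = 4, 1, 4
    submesh_choices = get_submesh_choices(num_hosts, num_devices_per_host)
    num_autosharding_configs = 2
    # compute_cost[k, i, m, c]: cost of running layers k..i on submesh m with autosharding config c.
    compute_cost = np.random.rand(num_layers, num_layers, len(submesh_choices), num_autosharding_configs)
    total_cost, stages = alpa_dp(num_layers, num_hosts * num_devices_per_host, 8,
                                 submesh_choices, num_autosharding_configs, compute_cost)
    # Each entry of `stages` is ((start_layer, next_start_layer), submesh_choice, autosharding_choice).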
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index f41af1161be1..72f199203a9d 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -40,14 +40,16 @@ class DeviceMesh:
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
- def __init__(self,
- physical_mesh_id: torch.Tensor,
- mesh_shape: torch.Size = None,
- logical_mesh_id: torch.Tensor = None,
- mesh_alpha: List[float] = None,
- mesh_beta: List[float] = None,
- init_process_group: bool = False,
- device: str = 'cuda'):
+ def __init__(
+ self,
+ physical_mesh_id: torch.Tensor,
+ mesh_shape: torch.Size = None,
+ logical_mesh_id: torch.Tensor = None,
+ mesh_alpha: List[float] = None,
+ mesh_beta: List[float] = None,
+ init_process_group: bool = False,
+ device: str = "cuda",
+ ):
# ============================
# Physical & Logical Mesh IDs
# ============================
@@ -57,9 +59,10 @@ def __init__(self,
# logical mesh ids can be obtained in two ways
# 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id
- assert mesh_shape is None or logical_mesh_id is None, \
- "Only one of mesh_shape and logical_mesh_id can be specified." \
+ assert mesh_shape is None or logical_mesh_id is None, (
+ "Only one of mesh_shape and logical_mesh_id can be specified."
"Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
+ )
if logical_mesh_id is None:
self._mesh_shape = mesh_shape
@@ -71,12 +74,15 @@ def __init__(self,
# ensure two things:
# 1. logical and physical mesh IDs should contain the same elements
# 2. there are no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
- assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
- "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
- assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
- "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
- assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
- "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
+ assert torch.equal(
+ torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
+ ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
+ assert (
+ torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
+ ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
+ assert (
+ torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
+ ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# ===============================================
# coefficient for alpha-beta communication model
@@ -92,8 +98,9 @@ def __init__(self,
self.mesh_beta = tuple(mesh_beta)
# ensure the alpha and beta have the same shape
- assert len(self.mesh_alpha) == len(self.mesh_beta), \
- "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
+ assert len(self.mesh_alpha) == len(
+ self.mesh_beta
+ ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
# =========================
# Device for Process Group
@@ -109,8 +116,9 @@ def __init__(self,
# <global rank>: [<local rank in axis 0>, <local rank in axis 1>, ...]
# }
self._global_to_local_rank_mapping = dict()
- self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
- tensor=self.logical_mesh_id)
+ self._init_global_to_logical_rank_mapping(
+ mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
+ )
# create process group
self._process_group_dict = {}
@@ -194,8 +202,9 @@ def _get_device_by_backend(process_group):
device_list = [_get_device_by_backend(pg) for pg in process_group]
# make sure all devices are the same
- assert all([device == device_list[0] for device in device_list]), \
- "All devices should be the same, please check your input process groups are created with the same distributed backend."
+ assert all(
+ [device == device_list[0] for device in device_list]
+ ), "All devices should be the same, please check your input process groups are created with the same distributed backend."
# create a fake physical mesh id
# as we only get the process group associated with the current process,
@@ -270,7 +279,7 @@ def __deepcopy__(self, memo) -> "DeviceMesh":
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
- if k != '_process_group_dict':
+ if k != "_process_group_dict":
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
# process group cannot be copied
@@ -278,10 +287,9 @@ def __deepcopy__(self, memo) -> "DeviceMesh":
setattr(result, k, v)
return result
- def _init_global_to_logical_rank_mapping(self,
- mapping: Dict,
- tensor: torch.Tensor,
- index_list: List[int] = []) -> Dict[int, List[int]]:
+ def _init_global_to_logical_rank_mapping(
+ self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
+ ) -> Dict[int, List[int]]:
"""
Build a global-rank-to-local-rank mapping for each process group along each axis of the logical device mesh.
@@ -311,15 +319,19 @@ def _init_global_to_logical_rank_mapping(self,
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
def init_logical_process_group(self):
- '''
+ """
This method is used to initialize the logical process groups, which will be used for communication
among the logical device mesh.
Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
communication-related functions such as ShapeConsistencyManager.apply will raise errors.
- '''
+ """
# sanity check
- assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group"
- assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
+ assert (
+ dist.is_initialized()
+ ), "torch.distributed must be initialized before calling init_logical_process_group"
+ assert (
+ not self._is_initialized
+ ), "The logical process group has been initialized, do not call init_logical_process_group twice"
# update the global rank of the current process
self._global_rank_of_current_process = dist.get_rank()
@@ -389,7 +401,7 @@ def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[i
return local_ranks
def _collate_global_ranks_in_same_process_group(self, global_rank):
- '''
+ """
Given a global rank, return all global ranks involved in its associated process group in each axis.
Example:
@@ -414,7 +426,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
0: [0, 4, 8, 12],
1: [0, 1, 2, 3]
# }
- '''
+ """
# We have initialized the global-rank-to-local-rank mapping by calling _init_global_to_logical_rank_mapping
# for self._global_to_local_rank_mapping
# the key is the global rank
@@ -437,7 +449,6 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
# in the same process group in the given axis
# the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]):
-
# if this dimension is not initialized yet,
# initialize it with an empty array
if dim not in processes_in_the_same_process_group:
@@ -478,29 +489,37 @@ def flatten(self):
flatten_mesh_shape_size = len(self._mesh_shape)
flatten_mesh_shape = [self.num_devices]
- return DeviceMesh(self._physical_mesh_id,
- tuple(flatten_mesh_shape),
- mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
- mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
- init_process_group=self._init_process_group)
+ return DeviceMesh(
+ self._physical_mesh_id,
+ tuple(flatten_mesh_shape),
+ mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+ mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+ init_process_group=self._init_process_group,
+ )
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
- 0.1)
+ return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1
def all_reduce_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
- 0.01)
+ return (
+ self.mesh_alpha[mesh_dim]
+ + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
+ + 0.01
+ )
def reduce_scatter_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
- 0.001)
+ return (
+ self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
+ )
def all_to_all_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
- (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
+ return (
+ self.mesh_alpha[mesh_dim]
+ + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
+ + 0.001
+ )
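Restating the alpha-beta cost expressions above as plain arithmetic makes the ring-algorithm
intuition visible: all_reduce moves roughly twice the bytes of all_gather or reduce_scatter,
while the small trailing constants (0.1, 0.01, 0.001) only nudge tie-breaking between
collectives. The alpha/beta values below are made-up numbers for illustration.

    alpha, beta, n = 1e-4, 1e-10, 4   # latency (s), per-byte cost (s/B), devices on the axis
    nbytes = 1 << 20                  # 1 MiB payload

    all_gather     = alpha + beta * (n - 1) / n * nbytes + 0.1
    all_reduce     = alpha + beta * 2 * (n - 1) / n * nbytes + 0.01
    reduce_scatter = alpha + beta * (n - 1) / n * nbytes + 0.001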
diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py
index 0444a4816273..4d40d5badfd0 100644
--- a/colossalai/fx/_compatibility.py
+++ b/colossalai/fx/_compatibility.py
@@ -2,16 +2,14 @@
import torch
-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
META_COMPATIBILITY = False
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
- from . import _meta_regist_12
META_COMPATIBILITY = True
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
- from . import _meta_regist_13
META_COMPATIBILITY = True
elif TORCH_MAJOR == 2:
META_COMPATIBILITY = True
@@ -36,7 +34,7 @@ def decorator(func):
else:
def wrapper(*args, **kwargs):
- raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
+ raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")
return wrapper
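The version gate above follows a common pattern: parse torch.__version__ once, then wrap
incompatible functions so they fail loudly at call time. A self-contained sketch of the same
idiom; requires_meta is a hypothetical name, not the decorator defined in this file.

    import torch

    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])
    META_OK = TORCH_MAJOR >= 2 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 12)

    def requires_meta(func):
        # Pass the function through unchanged when meta-tensor support exists.
        if META_OK:
            return func

        def wrapper(*args, **kwargs):
            raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")

        return wrapper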
diff --git a/colossalai/fx/_meta_regist_12.py b/colossalai/fx/_meta_regist_12.py
index 52e8d63ae543..63f88682e85a 100644
--- a/colossalai/fx/_meta_regist_12.py
+++ b/colossalai/fx/_meta_regist_12.py
@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
from torch.utils._pytree import tree_map
@@ -16,13 +16,11 @@
def register_meta(op, register_dispatcher=True):
-
def wrapper(f):
-
def add_func(op):
meta_table[op] = f
if register_dispatcher:
- name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+ name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -48,7 +46,6 @@ def meta_conv(
output_padding: List[int],
groups: int,
):
-
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula used to compute the length of one output dimension
@@ -125,7 +122,8 @@ def calc_conv_nd_return_shape(
kernel_size[i],
stride[i],
output_padding_list[i],
- ))
+ )
+ )
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
@@ -159,22 +157,42 @@ def pick_memory_format():
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
- out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
+ out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
-def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
- padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
- *extra_args):
+def meta_conv_1(
+ input_tensor: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ is_transposed: bool,
+ output_padding: List[int],
+ groups: int,
+ *extra_args,
+):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
-def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
- padding, dilation, transposed, output_padding, groups, output_mask):
- return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
+def meta_conv_backward(
+ grad_output: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias_sizes,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ output_mask,
+):
+ return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta")
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -208,7 +226,6 @@ def meta_cuda_rnn(
batch_sizes,
dropout_state,
):
-
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -224,8 +241,11 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
- out_shape = ([mini_batch, seq_length, out_size *
- num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+ out_shape = (
+ [mini_batch, seq_length, out_size * num_directions]
+ if batch_first
+ else [seq_length, mini_batch, out_size * num_directions]
+ )
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -242,18 +262,20 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
-def meta_cudnn_rnn_backward(input: torch.Tensor,
- weight: torch.Tensor,
- weight_stride0: int,
- hx: torch.Tensor,
- cx: Optional[torch.Tensor] = None,
- *args,
- **kwargs):
+def meta_cudnn_rnn_backward(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_stride0: int,
+ hx: torch.Tensor,
+ cx: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+):
print(input, weight, hx, cx)
grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight)
grad_hx = torch.empty_like(hx)
- grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
+ grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta")
return grad_input, grad_weight, grad_hx, grad_cx
@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((n_input), device='meta')
- running_var = torch.empty((n_input), device='meta')
+ running_mean = torch.empty((n_input), device="meta")
+ running_var = torch.empty((n_input), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
- save_invstd, train, eps, output_mask):
+def meta_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ train,
+ eps,
+ output_mask,
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((n_input), device='meta')
- running_var = torch.empty((n_input), device='meta')
- reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+ running_mean = torch.empty((n_input), device="meta")
+ running_var = torch.empty((n_input), device="meta")
+ reserve = torch.empty((0), dtype=torch.uint8, device="meta")
return output, running_mean, running_var, reserve
@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
-def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, eps, reserve):
+def meta_cudnn_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ eps,
+ reserve,
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((bs, n_input, 1), device='meta')
- running_var = torch.empty((bs, n_input, 1), device='meta')
+ running_mean = torch.empty((bs, n_input, 1), device="meta")
+ running_var = torch.empty((bs, n_input, 1), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
- grad_input_mask):
+def meta_ln_backward(
+ dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias)
@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
- assert index.dtype in [torch.long, torch.int8, torch.bool],\
- "tensors used as indices must be long, byte or bool tensors"
+ assert index.dtype in [
+ torch.long,
+ torch.int8,
+ torch.bool,
+ ], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
- assert index.shape[j] == self.shape[
- k +
- j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+ assert (
+ index.shape[j] == self.shape[k + j]
+ ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices):
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
-def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
- scale_grad_by_freq):
- return torch.empty((num_weights, grad_output.size(-1)),
- dtype=grad_output.dtype,
- device=grad_output.device,
- layout=grad_output.layout)
+def meta_embedding_dense_backward(
+ grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+):
+ return torch.empty(
+ (num_weights, grad_output.size(-1)),
+ dtype=grad_output.dtype,
+ device=grad_output.device,
+ layout=grad_output.layout,
+ )
# ============================== Dropout ===========================================
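For reference, the per-dimension output length computed by meta_conv's _formula follows the
standard PyTorch convolution convention (input length ln, padding p, dilation d, kernel k,
stride s); the snippet below restates it and checks one familiar case.

    def conv_out_len(ln: int, p: int, d: int, k: int, s: int) -> int:
        # floor((ln + 2p - d*(k - 1) - 1) / s) + 1, as in torch.nn.Conv2d's shape docs
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    # A 224-pixel axis through a ResNet stem (kernel 7, stride 2, padding 3) yields 112.
    assert conv_out_len(224, 3, 1, 7, 2) == 112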
diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index 33b164800262..dfb5754d71c1 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Tuple
import torch
@@ -18,6 +18,7 @@
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import (
@@ -32,12 +33,13 @@
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE:
- __all__ = ['ActivationCheckpointCodeGen']
+ __all__ = ["ActivationCheckpointCodeGen"]
else:
- __all__ = ['python_code_with_activation_checkpoint']
+ __all__ = ["python_code_with_activation_checkpoint"]
def _gen_saved_tensors_hooks():
@@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]):
Find the checkpoint regions given a list of consecutive nodes. The output will be a list
of tuples, each in the form (start_index, end_index).
"""
- ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_checkpoint' in node.meta:
- act_ckpt_label = node.meta['activation_checkpoint']
+ if "activation_checkpoint" in node.meta:
+ act_ckpt_label = node.meta["activation_checkpoint"]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
- elif current_region is not None and not 'activation_checkpoint' in node.meta:
+ elif current_region is not None and "activation_checkpoint" not in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
@@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
- act_offload_label = node.meta['activation_offload']
+ if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable):
+ act_offload_label = node.meta["activation_offload"]
if current_region == None:
current_region = act_offload_label
@@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
"""
Generate the checkpoint function call code text
"""
- outputs = ', '.join(output_vars)
- inputs = ', '.join(input_vars)
- return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
+ outputs = ", ".join(output_vars)
+ inputs = ", ".join(input_vars)
+ return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
@@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
Returns:
bool
"""
- if 'activation_checkpoint' in node.meta:
- if isinstance(node.meta['activation_checkpoint'], list):
- return node.meta['activation_checkpoint'][check_idx] == None
+ if "activation_checkpoint" in node.meta:
+ if isinstance(node.meta["activation_checkpoint"], list):
+ return node.meta["activation_checkpoint"][check_idx] is None
else:
return False
else:
@@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_checkpoint' in node.meta:
- if isinstance(node.meta['activation_checkpoint'], int):
- act_ckpt_label = node.meta['activation_checkpoint']
+ if "activation_checkpoint" in node.meta:
+ if isinstance(node.meta["activation_checkpoint"], int):
+ act_ckpt_label = node.meta["activation_checkpoint"]
else:
- act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
+ act_ckpt_label = node.meta["activation_checkpoint"][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
return ckpt_regions
-def emit_ckpt_func(body,
- ckpt_func,
- node_list: List[Node],
- emit_node_func,
- delete_unused_value_func,
- level=0,
- in_ckpt=False):
+def emit_ckpt_func(
+ body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False
+):
"""Emit ckpt function in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
@@ -321,17 +318,17 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function uses an int as its label, use the old generation method
- if isinstance(node_list[0].meta['activation_checkpoint'], int):
- label = node_list[0].meta['activation_checkpoint']
+ if isinstance(node_list[0].meta["activation_checkpoint"], int):
+ label = node_list[0].meta["activation_checkpoint"]
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
- activation_offload = node_list[0].meta.get('activation_offload', False)
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
+ activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
@@ -340,12 +337,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
- label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
+ label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
# if there are more levels to fetch
- if level + 1 < len(node_list[0].meta['activation_checkpoint']):
+ if level + 1 < len(node_list[0].meta["activation_checkpoint"]):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -358,38 +355,45 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
- emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
- delete_unused_value_func, level + 1, True)
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(
+ ckpt_func,
+ ckpt_func_buffer,
+ ckpt_node_list,
+ emit_node_func,
+ delete_unused_value_func,
+ level + 1,
+ True,
+ )
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
- activation_offload = node_list[0].meta.get('activation_offload', False)
- usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+ activation_offload = node_list[0].meta.get("activation_offload", False)
+ usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
- activation_offload = node_list[0].meta.get('activation_offload', False)
- usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
+ activation_offload = node_list[0].meta.get("activation_offload", False)
+ usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
@@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
- offload_node_list = node_list[start:end + 1]
+ offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
@@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process ckpt_regions
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
@@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
if within_offload_region:
emit_node_func(node, body)
- body[-1] = ' ' + body[-1]
+ body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
@@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
- ckpt_node_list = node_list[start:end + 1]
+ ckpt_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs)
output_vars.append(outputs)
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
- offload_node_list = node_list[start:end + 1]
+ offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
@@ -527,7 +531,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
within_ckpt_region = True
if idx in offload_starts:
@@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
elif within_offload_region:
emit_node_func(node, body)
- body[-1] = ' ' + body[-1]
+ body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
@@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
- return_statement = f' {return_statement}\n\n'
+ return_statement = f" {return_statement}\n\n"
ckpt_func.append(return_statement)
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
- if 'activation_offload' in node_list[start_node_idx].meta:
- activation_offload = node_list[start_node_idx].meta['activation_offload']
+ if "activation_offload" in node_list[start_node_idx].meta:
+ activation_offload = node_list[start_node_idx].meta["activation_offload"]
else:
activation_offload = False
@@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
- if 'activation_checkpoint' in user.meta:
- if user.meta['activation_checkpoint'] == label:
+ if "activation_checkpoint" in user.meta:
+ if user.meta["activation_checkpoint"] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
@@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
- usage += '\n'
+ usage += "\n"
body.append(usage)
within_ckpt_region = False
@@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if CODEGEN_AVAILABLE:
class ActivationCheckpointCodeGen(CodeGen):
-
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
@@ -629,7 +632,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -637,7 +640,7 @@ def add_global(name_hint: str, obj: Any):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -662,16 +665,16 @@ def add_global(name_hint: str, obj: Any):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
- if hasattr(o, '__args__'):
+ if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
@@ -690,19 +693,18 @@ def type_repr(o: Any):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
- if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
- args_s = ', '.join(_get_repr(a) for a in args)
- kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+ args_s = ", ".join(_get_repr(a) for a in args)
+ kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
- return f'{args_s}, {kwargs_s}'
+ return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
@@ -728,90 +730,101 @@ def delete_unused_values(user: Node, body):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
- body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
- f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use the nested activation checkpoint codegen
- if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
+ if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@@ -820,13 +833,13 @@ def emit_node(node: Node, body):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
@@ -837,11 +850,11 @@ def emit_node(node: Node, body):
# since we use colossalai.utils.checkpoint, we need to import colossalai
# in the forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
- prologue = ''.join(ckpt_func) + prologue
+ prologue = "".join(ckpt_func) + prologue
prologue = prologue
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
@@ -861,7 +874,7 @@ def python_code_with_activation_checkpoint(self, root_module: str, namespace: _N
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -869,7 +882,7 @@ def add_global(name_hint: str, obj: Any):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -894,12 +907,12 @@ def add_global(name_hint: str, obj: Any):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
# This is a generic type, e.g. typing.List[torch.Tensor]
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
@@ -934,84 +947,94 @@ def delete_unused_values(user: Node, body):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
if self._pytree_info is None:
- body.append(f'return {repr(node.args[0])}')
+ body.append(f"return {repr(node.args[0])}")
else:
- body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
+ body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)")
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use the nested activation checkpoint codegen
- if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
+ if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
@@ -1020,33 +1043,34 @@ def emit_node(node: Node, body):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if self._pytree_info is not None:
orig_args = self._pytree_info.orig_args
- has_orig_self = (orig_args[0] == 'self')
+ has_orig_self = orig_args[0] == "self"
if has_orig_self:
- free_vars.insert(0, 'self')
- if len(free_vars) > 0: # pytree has placeholders in it
+ free_vars.insert(0, "self")
+ if len(free_vars) > 0: # pytree has placeholders in it
body.insert(
0,
- f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
+ f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n",
+ )
else:
orig_args = free_vars
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
- ckpt_func = ''.join(ckpt_func)
+ ckpt_func = "".join(ckpt_func)
# If the original function didn't have self as its first argument, we
# would have added it.
- if len(orig_args) == 0 or orig_args[0] != 'self':
- orig_args.insert(0, 'self')
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ if len(orig_args) == 0 or orig_args[0] != "self":
+ orig_args.insert(0, "self")
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
    # as the generated code needs colossalai.utils.checkpoint, we import
    # colossalai in the forward function
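
The emit_node changes above are purely cosmetic (Black-style double quotes and wrapped f-strings); the forward source this codegen emits is unchanged. For readers unfamiliar with what that emitted source looks like, a minimal sketch using plain upstream torch.fx (not part of this patch) prints the per-node.op lines that emit_node assembles:

import torch
import torch.fx


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # traces into call_module, call_function, and output nodes
        return torch.relu(self.linear(x)) + 1


gm = torch.fx.symbolic_trace(Tiny())
print(gm.code)  # the forward source assembled line by line, as in emit_node
for node in gm.graph.nodes:
    print(node.op, node.target)  # placeholder / call_module / call_function / output
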
diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index ebb9975f27db..8429a9607f7a 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -1,32 +1,35 @@
import os
import warnings
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Type, Union
+from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
try:
- from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
- from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+ from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen
+ from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
+
COLOGM = True
except:
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
+
COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
-
- def __init__(self,
- root: Union[torch.nn.Module, Dict[str, Any]],
- graph: Graph,
- class_name: str = 'GraphModule',
- ckpt_codegen: bool = True):
+ def __init__(
+ self,
+ root: Union[torch.nn.Module, Dict[str, Any]],
+ graph: Graph,
+ class_name: str = "GraphModule",
+ ckpt_codegen: bool = True,
+ ):
if ckpt_codegen:
graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name)
@@ -60,7 +63,7 @@ def recompile(self) -> PythonCode:
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
- python_code = self._graph.python_code(root_module='self')
+ python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
@@ -83,8 +86,8 @@ def recompile(self) -> PythonCode:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
- if '_wrapped_call' not in vars(cls):
- cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
+ if "_wrapped_call" not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -108,7 +111,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
- torch.save(self.state_dict(), folder / 'state_dict.pt')
+ torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -125,7 +128,13 @@ def __init__(self):
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [
- nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
+ nn.Linear,
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ nn.BatchNorm1d,
+ nn.BatchNorm2d,
+ nn.BatchNorm3d,
]
if type(module) in safe_reprs:
return f"{module.__repr__()}"
@@ -136,10 +145,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
- module_file = folder / f'{module_name}.pt'
+ module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
- module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+ module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -156,19 +165,20 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
- module_file = folder / 'module.py'
+ module_file = folder / "module.py"
module_file.write_text(model_str)
- init_file = folder / '__init__.py'
- init_file.write_text('from .module import *')
+ init_file = folder / "__init__.py"
+ init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
- warnings.warn("Was not able to save the following children modules as reprs -"
- f"saved as pickled files instead: {blobified_modules}")
+ warnings.warn(
+ "Was not able to save the following children modules as reprs -"
+ f"saved as pickled files instead: {blobified_modules}"
+ )
else:
class ColoGraphModule(GraphModule):
-
- def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"):
super().__init__(root, graph, class_name)
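
Since ColoGraphModule only installs ActivationCheckpointCodeGen before deferring to torch.fx's GraphModule, wrapping a traced graph is a one-liner. A hedged usage sketch (assumes a torch version new enough that the try-import above succeeds, i.e. COLOGM is True):

import torch
import torch.fx

from colossalai.fx.graph_module import ColoGraphModule

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
graph = torch.fx.symbolic_trace(model).graph
# ckpt_codegen=True swaps in ActivationCheckpointCodeGen before construction
gm = ColoGraphModule(model, graph, ckpt_codegen=True)
print(gm.code)
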
diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 245ba5d776da..99c8faaa0cc6 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -1,8 +1,6 @@
import numpy as np
import torch
import tqdm
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop = 0
block_nodes = []
for node in gm.graph.nodes:
- if 'block_split' in node.name:
+ if "block_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
accumulate_bwd_flop += node.bwd_flop
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
with gm.graph.inserting_after(node):
- block_node = gm.graph.create_node('call_function', block_split)
- setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
- setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
+ block_node = gm.graph.create_node("call_function", block_split)
+ setattr(block_node, "fwd_flop", accumulate_fwd_flop)
+ setattr(block_node, "bwd_flop", accumulate_bwd_flop)
accumulate_fwd_flop = 0
accumulate_bwd_flop = 0
block_nodes.append(block_node)
@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def remove_blocks(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
- if (node.op, node.target) == ('call_function', block_split):
+ if (node.op, node.target) == ("call_function", block_split):
gm.graph.erase_node(node)
@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes = len(node_list)
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
- for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
- for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
+ for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
+ for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
all_compute_cost[start, end] = sum(selected_flops)
@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
f[0, num_nodes] = 0
- for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
- for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
- for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
+ for s in tqdm.tqdm(
+ range(1, num_stages + 1), desc="stage", position=2, leave=False
+ ): # pylint: disable=too-many-nested-blocks
+ for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
+ for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
stage_cost = compute_costs[i, k - 1]
new_cost = f[s - 1, k] + stage_cost
- if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
+ if stage_cost <= max_compute_cost and new_cost < f[s, i]:
f[s, i] = new_cost
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
f_argmin[s, i] = k
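
For readers following the dynamic program above: f[s, i] is the minimal total compute cost of packing nodes i..N-1 into s pipeline stages, a candidate stage covering nodes [i, k-1] is rejected once its cost exceeds the current max_compute_cost bound, and f_stage_max tracks the bottleneck stage. A self-contained sketch of the same recurrence (illustrative only; simplified from do_dp_split_gpipe_impl, without the argmin bookkeeping):

import numpy as np

def dp_min_total_cost(costs: np.ndarray, num_stages: int, max_stage_cost: float) -> float:
    """costs[i, k]: compute cost of placing nodes i..k into a single stage."""
    n = costs.shape[0]
    f = np.full((num_stages + 1, n + 1), np.inf)
    f[0, n] = 0.0  # no stages left, no nodes left
    for s in range(1, num_stages + 1):
        for i in range(n - 1, -1, -1):
            for k in range(n, i, -1):
                stage_cost = costs[i, k - 1]  # nodes i..k-1 form one stage
                new_cost = f[s - 1, k] + stage_cost
                if stage_cost <= max_stage_cost and new_cost < f[s, i]:
                    f[s, i] = new_cost
    return f[num_stages, 0]
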
@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
best_cost = np.inf
best_solution = None
last_max_compute_cost = 0.0
- gap = 1e6 # temporary magic number, unit: flops
+ gap = 1e6 # temporary magic number, unit: flops
for max_compute_cost in tqdm.tqdm(max_compute_costs):
# Pruning to reduce search space.
@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if max_compute_cost - last_max_compute_cost < gap:
continue
- cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
- max_compute_cost)
+ cost, solution = do_dp_split_gpipe_impl(
+ len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
+ )
if cost < best_cost:
best_cost = cost
@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode:
# 'node': partition at the granularity of single fx nodes
# 'block': partition at the granularity of blocks built from many fx nodes
-def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
- assert mode in ['node', 'block']
+def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
+ assert mode in ["node", "block"]
# nodes or blocks will be used in partition.
node_list = []
- if mode == 'node':
+ if mode == "node":
for node in gm.graph.nodes:
node_list.append(node)
- elif mode == 'block':
+ elif mode == "block":
node_list = construct_blocks(gm, limit=block_limit)
else:
pass
@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
- for (_, next_start_node) in best_solution:
+ for _, next_start_node in best_solution:
if pp_size <= 1:
break
node = node_list[next_start_node]
with gm.graph.inserting_before(node):
- split_node = gm.graph.create_node('call_function', pipe_split)
+ split_node = gm.graph.create_node("call_function", pipe_split)
pp_size -= 1
# remove block node if possible
- if mode == 'block':
+ if mode == "block":
remove_blocks(gm)
gm.recompile()
@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    # To use avgcompute_split_pass, we need to run the meta_info_prop interpreter first.
    # If nodes don't have meta info, this pass falls back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
- if 'tensor_meta' not in check_node.meta:
+ if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0
@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if 'pipe_split' in node.name:
+ if "pipe_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop:
@@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
- if node.next.op == 'output':
+ if node.next.op == "output":
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
    # If the next node is an output node, we insert the split annotation before
    # the current node to make sure there is at least one node in the last partition.
- if node.next.op == 'output':
+ if node.next.op == "output":
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if node.op == 'placeholder':
+ if node.op == "placeholder":
continue
elif node_counter == 0:
node_counter += 1
@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
    # To use balanced_split_pass_v2, we need to run the meta_info_prop interpreter first.
    # If nodes don't have meta info, this pass falls back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
- if 'tensor_meta' not in check_node.meta:
+ if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_element_size = 0
@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if 'pipe_split' in node.name:
+ if "pipe_split" in node.name:
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount = 0
pp_size -= 1
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def split_callback(n: torch.fx.Node):
nonlocal part_idx
- if (n.op, n.target) == ('call_function', pipe_split):
+ if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -355,7 +356,7 @@ def split_callback(n: torch.fx.Node):
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
- if (node.op, node.target) == ('call_function', pipe_split):
+ if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
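
Taken together, these passes first insert pipe_split sentinel nodes into the graph and then carve it at those sentinels. A hedged end-to-end sketch (assumes this repository's colossalai.fx.passes package; per the code above, split_with_split_nodes_pass is assumed to return the split module plus its submodules):

import torch
import torch.fx

from colossalai.fx.passes.adding_split_node_pass import (
    balanced_split_pass,
    split_with_split_nodes_pass,
)

model = torch.nn.Sequential(*[torch.nn.Linear(8, 8) for _ in range(4)])
gm = torch.fx.symbolic_trace(model)
annotated = balanced_split_pass(gm, pp_size=2)  # inserts pipe_split sentinel nodes
split_gm, submodules = split_with_split_nodes_pass(annotated)
print(split_gm)  # base module calling submod_0, submod_1, ... in sequence
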
diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py
index 81ac64205528..5440a4eadbbf 100644
--- a/colossalai/fx/passes/concrete_info_prop.py
+++ b/colossalai/fx/passes/concrete_info_prop.py
@@ -1,5 +1,5 @@
from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.fx
@@ -85,10 +85,10 @@ def run_node(self, n: Node) -> Any:
self._is_proped = True
result, meta_info = super().run_node(n)
- n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
+ n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
- setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
- n.meta['type'] = type(result)
+ setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
+ n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -98,7 +98,7 @@ def run_node(self, n: Node) -> Any:
# Main Node running APIs
@compatibility(is_backward_compatible=True)
- def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -119,7 +119,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -138,7 +138,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -157,7 +157,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di
return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -175,7 +175,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -197,7 +197,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -228,7 +228,7 @@ def propagate(self, *args):
"""
return self.run(*args)
- def summary(self, unit: str = 'MB') -> str:
+ def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -238,9 +238,11 @@ def summary(self, unit: str = 'MB') -> str:
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -249,10 +251,10 @@ def summary(self, unit: str = 'MB') -> str:
def mem_repr(mem: int) -> str:
unit_divisor_map = {
- 'kb': 1024,
- 'mb': 1024**2,
- 'gb': 1024**3,
- 'tb': 1024**4,
+ "kb": 1024,
+ "mb": 1024**2,
+ "gb": 1024**3,
+ "tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -261,30 +263,32 @@ def time_repr(time: float):
for node in self.module.graph.nodes:
node: Node
- node_summaries.append([
- node.op,
- str(node),
- time_repr(node.meta['fwd_time']),
- time_repr(node.meta['bwd_time']),
- node.meta['save_fwd_in'],
- mem_repr(node.meta['fwd_mem_out']),
- mem_repr(node.meta['fwd_mem_tmp']),
- mem_repr(node.meta['bwd_mem_out']),
- mem_repr(node.meta['bwd_mem_tmp']),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ time_repr(node.meta["fwd_time"]),
+ time_repr(node.meta["bwd_time"]),
+ node.meta["save_fwd_in"],
+ mem_repr(node.meta["fwd_mem_out"]),
+ mem_repr(node.meta["fwd_mem_tmp"]),
+ mem_repr(node.meta["bwd_mem_out"]),
+ mem_repr(node.meta["bwd_mem_tmp"]),
+ ]
+ )
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Forward time',
- 'Backward time',
- 'SAVE_FWD_IN',
- 'FWD_OUT',
- 'FWD_TMP',
- 'BWD_OUT',
- 'BWD_TMP',
+ "Op type",
+ "Op",
+ "Forward time",
+ "Backward time",
+ "SAVE_FWD_IN",
+ "FWD_OUT",
+ "FWD_TMP",
+ "BWD_OUT",
+ "BWD_TMP",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
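
The summary() refactor above only reflows the row and header literals; the table itself is unchanged. For reference, the rendering relies on nothing more than tabulate with right-aligned string columns, which a tiny standalone snippet reproduces (row values below are hypothetical, for illustration only):

from tabulate import tabulate

headers = ["Op type", "Op", "Forward time", "Backward time"]
rows = [
    ["call_module", "linear", "0.12 ms", "0.25 ms"],
    ["call_function", "relu", "0.03 ms", "0.05 ms"],
]
print(tabulate(rows, headers=headers, stralign="right"))
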
diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
index 4571bd93a790..3d032a27db63 100644
--- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
+++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
@@ -1,14 +1,11 @@
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
-from copy import deepcopy
+from typing import List
+
+import torch
+
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
def apply(*args, **kwargs):
@@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
- setattr(node, 'best_strategy', strategies_vector[strategy_index])
- setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
+ setattr(node, "best_strategy", strategies_vector[strategy_index])
+ setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec)
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
# apply the sharding spec of parameters
for node in nodes:
- if node.op == 'call_module':
+ if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
- setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
+ setattr(target_module.weight, "sharding_spec", origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
apply(target_module.weight, target_weight_sharding_spec)
@@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
# add above dicts into graph
for node in nodes:
- if node.op != 'placeholder':
+ if node.op != "placeholder":
with mod_graph.inserting_before(node):
- input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
- origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
+ input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+ origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
node_to_index_dict = {}
index = 0
for node in nodes:
- if node.target == 'sharding_spec_convert_dict':
+ if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
- if node.target == 'origin_node_sharding_spec_dict':
+ if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
# add shape consistency apply function into graph
for node in nodes:
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
with mod_graph.inserting_after(node):
- origin_spec_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(origin_dict_node, node_to_index_dict[node]))
+ origin_spec_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node])
+ )
with mod_graph.inserting_after(origin_spec_node):
- set_sharding_spec_node = mod_graph.create_node('call_function',
- builtins.setattr,
- args=(node, 'sharding_spec', origin_spec_node))
+ set_sharding_spec_node = mod_graph.create_node(
+ "call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node)
+ )
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
- input_specs_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(input_dict_node, node_to_index_dict[node]))
+ input_specs_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node])
+ )
with mod_graph.inserting_before(user_node):
- sharding_spec_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(input_specs_node, node_index))
+ sharding_spec_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(input_specs_node, node_index)
+ )
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
+ shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node))
return gm
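
The pass above rewrites the graph by creating call_function nodes for operator.getitem and builtins.setattr around each annotated node. For readers unfamiliar with this style of fx graph surgery, a minimal standalone example (plain torch.fx, unrelated to sharding) builds and runs such nodes directly:

import operator

import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
halves = g.call_function(torch.chunk, args=(x, 2))           # returns a tuple
first = g.call_function(operator.getitem, args=(halves, 0))  # index into it
g.output(first)

gm = torch.fx.GraphModule(torch.nn.Module(), g)
print(gm(torch.randn(4, 3)).shape)  # torch.Size([2, 3])
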
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index ab203dfd7440..1720aa58da2b 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -109,13 +109,13 @@ def extract_tensor_meta(obj):
return TensorMetadata(None, None, False, None, 0, False)
tensor_meta = tree_map(extract_tensor_meta, result)
- n.meta['tensor_meta'] = tensor_meta
- n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
+ n.meta["tensor_meta"] = tensor_meta
+ n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
- setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
- setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
- setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
- n.meta['type'] = type(result)
+ setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
+ setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
+ setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
+ n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -125,7 +125,7 @@ def extract_tensor_meta(obj):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
- def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -146,7 +146,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -165,7 +165,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -184,7 +184,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -202,7 +202,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -224,7 +224,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -240,7 +240,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str,
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
- if hasattr(args[0], '_tensor'):
+ if hasattr(args[0], "_tensor"):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)
@@ -257,7 +257,7 @@ def propagate(self, *args):
"""
return super().run(*args)
- def summary(self, unit: str = 'MB') -> str:
+ def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -267,9 +267,11 @@ def summary(self, unit: str = 'MB') -> str:
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -278,10 +280,10 @@ def summary(self, unit: str = 'MB') -> str:
def mem_repr(mem: int) -> str:
unit_divisor_map = {
- 'kb': 1024,
- 'mb': 1024**2,
- 'gb': 1024**3,
- 'tb': 1024**4,
+ "kb": 1024,
+ "mb": 1024**2,
+ "gb": 1024**3,
+ "tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -292,35 +294,37 @@ def flops_repr(flop: int) -> str:
for node in self.module.graph.nodes:
node: Node
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
- node_summaries.append([
- node.op,
- str(node),
- flops_repr(node.meta['fwd_flop']),
- flops_repr(node.meta['bwd_flop']),
- mem_repr(accumulate_size),
- mem_repr(calculate_fwd_in(node)),
- mem_repr(calculate_fwd_out(node)),
- mem_repr(calculate_fwd_tmp(node)),
- mem_repr(node.meta['bwd_mem_out']),
- mem_repr(node.meta['bwd_mem_tmp']),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ flops_repr(node.meta["fwd_flop"]),
+ flops_repr(node.meta["bwd_flop"]),
+ mem_repr(accumulate_size),
+ mem_repr(calculate_fwd_in(node)),
+ mem_repr(calculate_fwd_out(node)),
+ mem_repr(calculate_fwd_tmp(node)),
+ mem_repr(node.meta["bwd_mem_out"]),
+ mem_repr(node.meta["bwd_mem_tmp"]),
+ ]
+ )
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Forward FLOPs',
- 'Backward FLOPs',
- 'Accumulated Memory',
- 'FWD_IN',
- 'FWD_OUT',
- 'FWD_TMP',
- 'BWD_OUT',
- 'BWD_TMP',
+ "Op type",
+ "Op",
+ "Forward FLOPs",
+ "Backward FLOPs",
+ "Accumulated Memory",
+ "FWD_IN",
+ "FWD_OUT",
+ "FWD_TMP",
+ "BWD_OUT",
+ "BWD_TMP",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
+
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs)
if verbose:
interp.summary(unit)
- gm.to('cpu')
+ gm.to("cpu")
del interp
return gm
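
metainfo_trace, shown above, wraps the interpreter end to end: it moves the module to the available device, wraps the inputs in MetaTensor when the torch version supports it, and runs propagation. A hedged usage sketch (assumes a traceable model and, for verbose=True, the tabulate package):

import torch
import torch.fx

from colossalai.fx.passes.meta_info_prop import metainfo_trace

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model)
gm = metainfo_trace(gm, torch.randn(4, 16), verbose=True, unit="MB")
for node in gm.graph.nodes:
    print(node.name, node.meta.get("fwd_flop", 0))  # annotated by the pass
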
diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py
index efdd34a01fe0..73379f73689c 100644
--- a/colossalai/fx/passes/passes_for_gpt2_test.py
+++ b/colossalai/fx/passes/passes_for_gpt2_test.py
@@ -5,7 +5,6 @@
from packaging import version
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
-from torch.fx.node import Node
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
from colossalai.fx.passes.meta_info_prop import TensorMetadata
@@ -13,9 +12,9 @@
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
- '''
+ """
    This pass is only used for the GPT-2 performance test; it may be moved into adding_split_node_pass.py and will be deprecated in the future.
- '''
+ """
mod_graph = gm.graph
valid_children_size = 0
valid_children = []
@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
part_index += 1
pp_size -= 1
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
- '''
+ """
    This pass is used in the GPT-2 test; only part of its changes may be merged into
    split_with_split_nodes_pass, and it will be deprecated in the future.
- '''
+ """
part_idx = 0
def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes:
- if node.op == 'placeholder':
+ if node.op == "placeholder":
if not len(node.users):
gm.graph.erase_node(node)
gm.recompile()
return gm
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
- '''
+ """
        This method eliminates outputs of the previous partition that are unused in the next partition.
        The split module pass treats partitions as a DAG, but pipeline parallelism needs to treat them as a singly linked list.
        The difference: if an output of partition 0 is an input argument of partition 3, the DAG will not pass it
        through partitions 1 and 2, whereas a singly linked list must do so.
- '''
+ """
output_type = None
output_args = []
non_output_list = []
new_placeholder_list = []
for node in gm.graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__
output_args.extend([n.name for n in node.args[0]])
@@ -114,7 +113,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders):
continue
for node in gm.graph.nodes:
- if node.op == 'placeholder':
+ if node.op == "placeholder":
new_placeholder_list.append(node.name)
if output_type is not None:
gm.graph.output(output_type(output_args))
@@ -125,7 +124,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders):
def split_callback(n: torch.fx.Node):
nonlocal part_idx
- if (n.op, n.target) == ('call_function', pipe_split):
+ if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -134,7 +133,7 @@ def split_callback(n: torch.fx.Node):
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
- if (node.op, node.target) == ('call_function', pipe_split):
+ if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
@@ -200,13 +199,12 @@ def _gen_all_ancestors_set(node):
_gen_all_ancestors_set(node)
for n in list(all_ancestors):
- if n.op != 'placeholder' and n._fx_partition > partition_name:
+ if n.op != "placeholder" and n._fx_partition > partition_name:
n._fx_partition = partition_name
- def record_cross_partition_use(def_node: torch.fx.node.Node,
- use_node: Optional[torch.fx.node.Node]): # noqa: B950
- def_partition_name = getattr(def_node, '_fx_partition', None)
- use_partition_name = getattr(use_node, '_fx_partition', None)
+ def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def_partition_name = getattr(def_node, "_fx_partition", None)
+ use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
# if 'tensor_meta' in def_node.meta:
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
@@ -237,7 +235,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
if node.op in ["placeholder"]:
continue
- if node.op == 'output':
+ if node.op == "output":
# partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name
@@ -252,12 +250,12 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
- origin_partition_name = getattr(node, '_fx_partition', None)
+ origin_partition_name = getattr(node, "_fx_partition", None)
if origin_partition_name is None:
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
- torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
+ torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -287,7 +285,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
- if hasattr(node, '_fx_partition'):
+ if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -295,26 +293,24 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
- if node.op not in ['call_module', 'get_attr']:
+ if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
- target_atoms = node.target.split('.')
+ target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
- raise RuntimeError(f'Operator target {node.target} not found!')
+ raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
- target = '_'.join(target_atoms)
+ target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
- new_node = partition.graph.create_node(op=node.op,
- target=target,
- args=gathered_args,
- kwargs=gathered_kwargs,
- name=node.name)
+ new_node = partition.graph.create_node(
+ op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
+ )
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -323,14 +319,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
- if node.op == 'placeholder':
- if version.parse(torch.__version__) < version.parse('1.11.0'):
+ if node.op == "placeholder":
+ if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
- base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
- type_expr=node.type,
- default_value=default_value)
+ base_mod_env[node.name] = base_mod_graph.placeholder(
+ node.name, type_expr=node.type, default_value=default_value
+ )
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -344,13 +340,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
- output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
+ output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
- submod_name = f'submod_{partition_name}'
- base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
- partition.graph) # noqa: B950
+ submod_name = f"submod_{partition_name}"
+ base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
+ partition.targets, partition.graph
+ ) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -358,14 +355,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
- base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
+ base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
- if node.op == 'output':
- base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
+ if node.op == "output":
+ base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index ccbab0c38a29..be8261f2a3f4 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -9,8 +9,19 @@
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
- torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
- operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+ torch.add,
+ operator.add,
+ torch.abs,
+ torch.cos,
+ torch.exp,
+ torch.mul,
+ operator.mul,
+ operator.floordiv,
+ operator.truediv,
+ operator.neg,
+ torch.multiply,
+ torch.nn.functional.relu,
+ torch.nn.functional.dropout,
]
@@ -72,7 +83,7 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# traverse the graph to look for consecutive linear layers
is_linear_module = False
- if node.op == 'call_module':
+ if node.op == "call_module":
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
@@ -82,31 +93,31 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
- annotation_record['row'] = module
+ annotation_record["row"] = module
for shard_type, module in annotation_record.items():
# add row sharding spec
- if shard_type == 'row':
+ if shard_type == "row":
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
- setattr(module.weight, 'pg', process_group)
- setattr(module.weight, 'dist_spec', dist_spec)
- setattr(module.weight, 'comp_spec', comp_spec)
- elif shard_type == 'col':
+ setattr(module.weight, "pg", process_group)
+ setattr(module.weight, "dist_spec", dist_spec)
+ setattr(module.weight, "comp_spec", comp_spec)
+ elif shard_type == "col":
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
- setattr(module.weight, 'pg', process_group)
- setattr(module.weight, 'dist_spec', weight_dist_spec)
- setattr(module.weight, 'comp_spec', weight_comp_spec)
+ setattr(module.weight, "pg", process_group)
+ setattr(module.weight, "dist_spec", weight_dist_spec)
+ setattr(module.weight, "comp_spec", weight_comp_spec)
if module.bias is not None:
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
- setattr(module.bias, 'pg', process_group)
- setattr(module.bias, 'dist_spec', bias_dist_spec)
- setattr(module.bias, 'comp_spec', bias_comp_spec)
+ setattr(module.bias, "pg", process_group)
+ setattr(module.bias, "dist_spec", bias_dist_spec)
+ setattr(module.bias, "comp_spec", bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
@@ -114,16 +125,16 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
- annotation_record['col'] = module
+ annotation_record["col"] = module
if start_tracking and not is_linear_module:
        # check against the whitelist;
        # if a non-elementwise op is found, we reset the tracking
- if node.op == 'call_module':
+ if node.op == "call_module":
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
- elif node.op == 'call_function' or node.op == 'call_method':
+ elif node.op == "call_function" or node.op == "call_method":
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
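
The traversal above implements the Megatron-LM 1D pattern: when two Linear layers are separated only by element-wise ops, the first is column-sharded with output replication disabled and the second is row-sharded, so a single all-reduce at the end restores the full activation. A minimal sketch of just the pattern detection (find_linear_pairs is a hypothetical helper working at module level rather than on the fx graph):

import torch.nn as nn

ELEMENTWISE = (nn.Dropout, nn.ReLU)

def find_linear_pairs(modules):
    """Yield (col_linear, row_linear) pairs separated only by element-wise ops."""
    first = None
    for mod in modules:
        if isinstance(mod, nn.Linear):
            if first is None:
                first = mod       # candidate for column sharding
            else:
                yield first, mod  # second consecutive linear: row sharding
                first = None
        elif not isinstance(mod, ELEMENTWISE):
            first = None          # a non-elementwise op breaks the pattern

seq = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8),
                    nn.Sigmoid(), nn.Linear(8, 8))
print(list(find_linear_pairs(seq)))  # only the first (Linear, Linear) pair qualifies
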
diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py
index 61ed037ab7a1..67a2432595d6 100644
--- a/colossalai/fx/passes/split_module.py
+++ b/colossalai/fx/passes/split_module.py
@@ -25,12 +25,14 @@ def __init__(self, name: str):
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
- return f"name: {self.name},\n" \
- f" nodes: {self.node_names},\n" \
- f" inputs: {self.inputs},\n" \
- f" outputs: {self.outputs},\n" \
- f" partitions dependent on: {self.partitions_dependent_on},\n" \
+ return (
+ f"name: {self.name},\n"
+ f" nodes: {self.node_names},\n"
+ f" inputs: {self.inputs},\n"
+ f" outputs: {self.outputs},\n"
+ f" partitions dependent on: {self.partitions_dependent_on},\n"
f" partition dependents: {self.partition_dependents}"
+ )
# Creates subgraphs out of main graph
@@ -117,10 +119,9 @@ def forward(self, x, y):
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
- def record_cross_partition_use(def_node: torch.fx.node.Node,
- use_node: Optional[torch.fx.node.Node]): # noqa: B950
- def_partition_name = getattr(def_node, '_fx_partition', None)
- use_partition_name = getattr(use_node, '_fx_partition', None)
+ def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def_partition_name = getattr(def_node, "_fx_partition", None)
+ use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
@@ -134,7 +135,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
- def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
@@ -161,7 +162,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
if node.op in ["placeholder"]:
continue
- if node.op == 'output':
+ if node.op == "output":
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
@@ -178,7 +179,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
- torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
+ torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -208,7 +209,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
- if hasattr(node, '_fx_partition'):
+ if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -216,25 +217,24 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
- if node.op not in ['call_module', 'get_attr']:
+ if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
- target_atoms = node.target.split('.')
+ target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
- raise RuntimeError(f'Operator target {node.target} not found!')
+ raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
- target = '_'.join(target_atoms)
+ target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
- new_node = partition.graph.create_node(op=node.op,
- target=target,
- args=gathered_args,
- kwargs=gathered_kwargs)
+ new_node = partition.graph.create_node(
+ op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
+ )
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -243,14 +243,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
- if node.op == 'placeholder':
- if version.parse(torch.__version__) < version.parse('1.11.0'):
+ if node.op == "placeholder":
+ if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
- base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
- type_expr=node.type,
- default_value=default_value)
+ base_mod_env[node.name] = base_mod_graph.placeholder(
+ node.target, type_expr=node.type, default_value=default_value
+ )
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -264,13 +264,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
- output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
+ output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
- submod_name = f'submod_{partition_name}'
- base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
- partition.graph) # noqa: B950
+ submod_name = f"submod_{partition_name}"
+ base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
+ partition.targets, partition.graph
+ ) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -278,15 +279,15 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
- base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
+ base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
- if node.op == 'output':
- base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
+ if node.op == "output":
+ base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
for partition_name in sorted_partitions:
partition = partitions[partition_name]
diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py
index bb4f3cd6a490..c51f49a30e8a 100644
--- a/colossalai/fx/passes/utils.py
+++ b/colossalai/fx/passes/utils.py
@@ -1,7 +1,9 @@
-import torch
from typing import Dict
-from torch.fx.node import Node, map_arg
+
+import torch
from torch.fx.graph import Graph
+from torch.fx.node import Node, map_arg
+
def get_comm_size(prev_partition, next_partition):
"""
@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
- comm_size += n.meta['tensor_meta'].numel
+ comm_size += n.meta["tensor_meta"].numel
visited_nodes.add(n)
return comm_size
@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
"""
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = []
for node in input_nodes.keys():
- if node.op == 'placeholder':
+ if node.op == "placeholder":
placeholder_nodes.append(node)
for node in placeholder_nodes:
input_nodes.pop(node)
@@ -60,13 +62,13 @@ def get_top(graph: Graph):
"""
top_node_list = set()
for node in graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
continue
is_top = False
def _get_top(node):
nonlocal is_top
- if node.op == 'placeholder':
+ if node.op == "placeholder":
is_top = True
map_arg(node.args, lambda n: _get_top(n))
@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
def get_all_consumers(graph: Graph, node: Node):
"""
Given a graph and a node of this graph, return all consumers of the node.
-
+
Returns:
        List of ``Node``s whose ``args`` or ``kwargs`` contain the given node.
"""
@@ -120,7 +122,7 @@ def forward(self, x):
for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level)
-
+
Output:
graph():
%x : [#users=2] = placeholder[target=x]
@@ -148,7 +150,7 @@ def forward(self, x):
while nodes_to_process:
new_process_list = []
for node in nodes_to_process:
- if node.op == 'output':
+ if node.op == "output":
continue
node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node))
@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
torch.nn.Module: the module associated with the given node
"""
- assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
- assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
+ assert (
+ node.graph.owning_module is not None
+ ), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
+ assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
module = node.graph.owning_module.get_submodule(node.target)
return module
-
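
For orientation, `get_leaf` above collects every node that feeds the graph's output. A minimal sketch of the same idea using only public torch.fx APIs (the module and shapes here are illustrative):

    import torch
    import torch.fx
    from torch.fx.node import map_arg

    class TwoLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(4, 4)
            self.fc2 = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc2(self.fc1(x))

    gm = torch.fx.symbolic_trace(TwoLayer())
    leaves = {}
    for node in gm.graph.nodes:
        if node.op == "output":
            # every node referenced by the output is a leaf of the graph
            map_arg(node.args, lambda n: leaves.setdefault(n))
            map_arg(node.kwargs, lambda n: leaves.setdefault(n))
    print([n.name for n in leaves])  # expected: ['fc2']
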
diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py
index 8bcbde0eb23b..89dd2b3df617 100644
--- a/colossalai/fx/profiler/__init__.py
+++ b/colossalai/fx/profiler/__init__.py
@@ -12,7 +12,16 @@
)
from .tensor import MetaTensor
else:
- from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
+ from .experimental import (
+ meta_profiler_function,
+ meta_profiler_module,
+ profile_function,
+ profile_method,
+ profile_module,
+ calculate_fwd_in,
+ calculate_fwd_tmp,
+ calculate_fwd_out,
+ )
from .dataflow import GraphInfo
from .memory_utils import activation_size, is_inplace, parameter_size
diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py
index 5763a46dc83f..fad9bb272bff 100644
--- a/colossalai/fx/profiler/constants.py
+++ b/colossalai/fx/profiler/constants.py
@@ -1,6 +1,6 @@
import torch
-__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD']
+__all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"]
aten = torch.ops.aten
diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py
index a5e8880322b8..05f9b50ce575 100644
--- a/colossalai/fx/profiler/dataflow.py
+++ b/colossalai/fx/profiler/dataflow.py
@@ -1,6 +1,5 @@
from dataclasses import dataclass, field
from enum import Enum
-from functools import partial
from typing import Dict, List
from torch.fx import Graph, Node
@@ -69,8 +68,8 @@ class GraphInfo:
def is_phase(n: Node, phase: Phase) -> bool:
- assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
- return n.meta['phase'] == phase
+ assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
+ return n.meta["phase"] == phase
@compatibility(is_backward_compatible=False)
@@ -103,9 +102,9 @@ def _peak_memory(deps: Dict[Node, int]):
peak_mem = 0
for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
- peak_mem += activation_size(k.meta['saved_tensor'])
- if v <= float('-inf') and is_phase(k, Phase.FORWARD):
- peak_mem -= activation_size(k.meta['saved_tensor'])
+ peak_mem += activation_size(k.meta["saved_tensor"])
+ if v <= float("-inf") and is_phase(k, Phase.FORWARD):
+ peak_mem -= activation_size(k.meta["saved_tensor"])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
@@ -123,19 +122,19 @@ def _peak_memory(deps: Dict[Node, int]):
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
- graph_info.fwd_in += n.meta['saved_tensor']
+ graph_info.fwd_in += n.meta["saved_tensor"]
if is_phase(n, Phase.FORWARD):
- graph_info.fwd_tmp += n.meta['saved_tensor']
+ graph_info.fwd_tmp += n.meta["saved_tensor"]
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else:
# TODO: some of the bwd_mem_out might be model parameters.
            # basically, a backward node without users is a `grad_out` node
- graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
+ graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
if deps[input_n] <= 0:
- deps[input_n] = float('-inf')
+ deps[input_n] = float("-inf")
return graph_info
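
`activation_size` here comes from `memory_utils`; as a rough mental model (an illustrative stand-in, not the library's implementation), it totals the byte footprint of every tensor in a possibly nested output:

    import torch

    def activation_size_sketch(out) -> int:
        # sum numel * element_size over all tensors in a nested structure
        if isinstance(out, torch.Tensor):
            return out.numel() * out.element_size()
        if isinstance(out, (list, tuple)):
            return sum(activation_size_sketch(x) for x in out)
        return 0

    x = torch.randn(2, 3)                        # float32: 4 bytes/element
    print(activation_size_sketch((x, x.sum())))  # 6*4 + 1*4 = 28
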
diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py
index 57ff3fd91299..02758e7643af 100644
--- a/colossalai/fx/profiler/experimental/constants.py
+++ b/colossalai/fx/profiler/experimental/constants.py
@@ -2,7 +2,7 @@
import torch
-__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
+__all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"]
# TODO fill out the inplace ops
INPLACE_OPS = [
@@ -20,25 +20,25 @@
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
- 'transpose',
- 'permute',
+ "transpose",
+ "permute",
# TODO: reshape may return a copy of the data if the data is not contiguous
- 'reshape',
- 'dim',
- 'flatten',
- 'size',
- 'view',
- 'unsqueeze',
- 'to',
- 'type',
- 'flatten',
+ "reshape",
+ "dim",
+ "flatten",
+ "size",
+ "view",
+ "unsqueeze",
+ "to",
+ "type",
+ "flatten",
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
- 'chunk',
- 'contiguous',
- 'expand',
- 'mean',
- 'split',
+ "chunk",
+ "contiguous",
+ "expand",
+ "mean",
+ "split",
]
diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py
index 5c545260e72b..d890fdb66fc2 100644
--- a/colossalai/fx/profiler/experimental/profiler.py
+++ b/colossalai/fx/profiler/experimental/profiler.py
@@ -9,7 +9,7 @@
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
-__all__ = ['profile_function', 'profile_module', 'profile_method']
+__all__ = ["profile_function", "profile_module", "profile_method"]
# this is kept for compatibility
@@ -42,6 +42,7 @@ class GraphInfo:
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
+
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
@@ -50,8 +51,7 @@ class GraphInfo:
bwd_mem_out: int = 0
-CALL_FUNCTION_MSG = \
-"""
+CALL_FUNCTION_MSG = """
Colossal-AI does not yet support profiling for {}; you may patch it manually with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
macs = ...
return flops, macs
"""
-CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
-CALL_MODULE_MSG = \
-"""
+CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}"
+CALL_MODULE_MSG = """
Colossal-AI does not yet support profiling for {}; you may patch it manually with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
-def profile_function(target: 'Target') -> Callable:
+def profile_function(target: "Target") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
- target.__name__), CALL_FUNCTION_MSG.format(target)
+ target.__name__
+ ), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
- if target not in INPLACE_OPS and not kwargs.get('inplace', False):
+ if target not in INPLACE_OPS and not kwargs.get("inplace", False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
@@ -112,7 +112,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@compatibility(is_backward_compatible=True)
-def profile_method(target: 'Target') -> Callable:
+def profile_method(target: "Target") -> Callable:
"""
    Wrap a `call_method` node to
    record the memory cost and FLOPs of the execution.
@@ -126,11 +126,12 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
self_obj, *args_tail = args
# execute the method and return the result
- assert isinstance(target, str), f'{target} instance is not str.'
+ assert isinstance(target, str), f"{target} instance is not str."
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
- target, INPLACE_METHOD, NON_INPLACE_METHOD)
+ target, INPLACE_METHOD, NON_INPLACE_METHOD
+ )
    # call_method nodes have no parameters, are MOSTLY(?) inplace, and incur no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
@@ -161,7 +162,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
- if getattr(module, 'inplace', False):
+ if getattr(module, "inplace", False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
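
`CALL_FUNCTION_MSG` above doubles as a recipe. Following it, a sketch of patching an unsupported function; the FLOP/MAC numbers are illustrative assumptions, not authoritative counts:

    from typing import Tuple

    import torch
    from colossalai.fx.profiler.experimental import meta_profiler_function

    @meta_profiler_function.register(torch.nn.functional.elu)
    def profile_elu(input: torch.Tensor, *args) -> Tuple[int, int]:
        flops = input.numel()  # assumed: one FLOP per element
        macs = 0               # no multiply-accumulates involved
        return flops, macs
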
diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
index a43aef063e19..c518ec28da41 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
# TODO: different activations have different FLOP counts; currently unused.
diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
index 8d1c8a8c6877..f1b9bb97c6c6 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
@@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other):
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
-@meta_profiler_function.register('add') # for built-in op +
-@meta_profiler_function.register('iadd') # for built-in op +=
-@meta_profiler_function.register('eq') # for built-in op =
-@meta_profiler_function.register('sub') # for built-in op -
-@meta_profiler_function.register('isub') # for built-in op -=
-@meta_profiler_function.register('mul') # for built-in op *
-@meta_profiler_function.register('imul') # for built-in op *=
-@meta_profiler_function.register('floordiv') # for built-in op //
-@meta_profiler_function.register('ifloordiv') # for built-in op //=
+@meta_profiler_function.register("add") # for built-in op +
+@meta_profiler_function.register("iadd") # for built-in op +=
+@meta_profiler_function.register("eq") # for built-in op =
+@meta_profiler_function.register("sub") # for built-in op -
+@meta_profiler_function.register("isub") # for built-in op -=
+@meta_profiler_function.register("mul") # for built-in op *
+@meta_profiler_function.register("imul") # for built-in op *=
+@meta_profiler_function.register("floordiv") # for built-in op //
+@meta_profiler_function.register("ifloordiv") # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
@@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N
@meta_profiler_function.register(torch.matmul)
-@meta_profiler_function.register('matmul') # for built-in op @
+@meta_profiler_function.register("matmul") # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = reduce(operator.mul, input.shape) * other.shape[-1]
@@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T
@meta_profiler_function.register(torch.var_mean)
-def torch_var_mean(input: torch.Tensor,
- dim: Union[int, Tuple[int, ...]],
- unbiased: Optional[bool] = True,
- keepdim: Optional[bool] = False,
- *,
- out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
- assert out is None, 'saving to out is not supported yet'
+def torch_var_mean(
+ input: torch.Tensor,
+ dim: Union[int, Tuple[int, ...]],
+ unbiased: Optional[bool] = True,
+ keepdim: Optional[bool] = False,
+ *,
+ out: Optional[torch.Tensor] = None,
+) -> Tuple[int, int]:
+ assert out is None, "saving to out is not supported yet"
flops = input.numel() * 3
macs = 0
return flops, macs
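
As a sanity check on `torch_matmul` above (FLOPs are conventionally twice the MACs): for an (8, 64) input multiplied by a (64, 32) matrix,

    import operator
    from functools import reduce

    import torch

    input = torch.randn(8, 64)
    other = torch.randn(64, 32)
    macs = reduce(operator.mul, input.shape) * other.shape[-1]  # 8*64*32
    flops = 2 * macs
    print(macs, flops)  # 16384 32768
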
diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py
index d6e43d781b8b..1d362015fc8b 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/embedding.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py
@@ -1,5 +1,7 @@
-import torch
from typing import Optional
+
+import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py
index 01fe4c871370..ecc578d61b91 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/linear.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py
index c4ea508d70f8..2ad029eda039 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/normalization.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py
@@ -1,5 +1,7 @@
from typing import List, Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_function
@@ -21,11 +23,13 @@ def torch_nn_func_instancenorm(
@meta_profiler_function.register(torch.nn.functional.group_norm)
-def torch_nn_func_groupnorm(input: torch.Tensor,
- num_groups: int,
- weight: Optional[torch.Tensor] = None,
- bias: Optional[torch.Tensor] = None,
- eps: float = 1e-5) -> Tuple[int, int]:
+def torch_nn_func_groupnorm(
+ input: torch.Tensor,
+ num_groups: int,
+ weight: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ eps: float = 1e-5,
+) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py
index a639f5ee83c1..c91deab906d4 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/pooling.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py
@@ -1,5 +1,7 @@
-from typing import Tuple, Union
+from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
index 1e8561206ba0..58c9889ad98e 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
@@ -1,6 +1,6 @@
import operator
from typing import Any, Tuple
-import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
index abdd7ad565ba..67e90fb69acd 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
@@ -1,7 +1,9 @@
-from functools import reduce
import operator
+from functools import reduce
from typing import Any, Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_function
@@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
@meta_profiler_function.register(torch.max)
-def torch_max(input: torch.Tensor,
- dim: int = None,
- keepdim: bool = False,
- *,
- out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+def torch_max(
+ input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None
+) -> Tuple[int, int]:
macs = 0
- assert out is None, 'assigning value to out is not supported yet'
+ assert out is None, "assigning value to out is not supported yet"
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
index 2ebf514ad269..ae065e0c7c17 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
# TODO: different activations have different FLOP counts; currently unused.
diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py
index 8daf74b232bf..dfaee75e0432 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/attention.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py
@@ -1,19 +1,23 @@
from typing import Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_module
# TODO: This is hard to compute memory cost
@meta_profiler_module.register(torch.nn.MultiheadAttention)
-def torch_nn_msa(self: torch.nn.MultiheadAttention,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- key_padding_mask: Optional[torch.Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[torch.Tensor] = None,
- average_attn_weights: bool = True) -> Tuple[int, int]:
- if getattr(self, 'batch_first', False):
+def torch_nn_msa(
+ self: torch.nn.MultiheadAttention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_padding_mask: Optional[torch.Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+ average_attn_weights: bool = True,
+) -> Tuple[int, int]:
+ if getattr(self, "batch_first", False):
batch_size = query.shape[0]
len_idx = 1
else:
@@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
flops += qlen * qdim
# Initial projections
- flops += 2 * ((qlen * qdim * qdim) # QW
- + (klen * kdim * kdim) # KW
- + (vlen * vdim * vdim) # VW
- )
+ flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)) # QW # KW # VW
- macs += ((qlen * qdim * qdim) # QW
- + (klen * kdim * kdim) # KW
- + (vlen * vdim * vdim) # VW
- )
+ macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim) # QW # KW # VW
if self.in_proj_bias is not None:
flops += (qlen + klen + vlen) * qdim
@@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
v_head_dim = vdim // num_heads
head_flops = (
- 2 * (qlen * klen * qk_head_dim) # QK^T
- + (qlen * klen) # softmax
- + 2 * (qlen * klen * v_head_dim) # AV
+ 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim) # QK^T # softmax # AV
)
- head_macs = ((qlen * klen * qk_head_dim) # QK^T
- + 2 * (qlen * klen * v_head_dim) # AV
- )
+ head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim) # QK^T # AV
flops += num_heads * head_flops
    macs += num_heads * head_macs
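
To make the per-head terms concrete: with qlen = klen = 128 and 64-dimensional heads,

    qlen = klen = 128
    qk_head_dim = v_head_dim = 64

    head_flops = 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim)
    head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim)
    print(head_flops, head_macs)  # 4210688 3145728
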
diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
index a4c15b91e611..90e494c77f5b 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/convolution.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
@@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
- l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+ l_out = math.floor(
+ (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
@@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+ h_out = math.floor(
+ (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
@@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
- d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
- w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
- (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
+ d_out = math.floor(
+ (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ h_out = math.floor(
+ (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
@@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
- l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
@@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(
operator.mul, input.shape
- ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
+ ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
@@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
@@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
- d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
- (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[2]
+ - 2 * self.padding[2]
+ + self.dilation[2] * (self.kernel_size[2] - 1)
+ + self.output_padding[2]
+ + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
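
The repeated floor expressions are the standard convolution output-size formula. A quick check for a 3x3 kernel with stride 1, dilation 1, and padding 1 along a 32-pixel axis (values chosen for illustration):

    import math

    h_in, padding, dilation, kernel_size, stride = 32, 1, 1, 3, 1
    h_out = math.floor((h_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
    print(h_out)  # 32 -- "same" padding preserves the spatial size
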
diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py
index 417e0ed46863..7361239eb1bd 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/dropout.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py
index e1ffb6f244d2..71fed3196c13 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/linear.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
index 49e5e6fa5384..5a64e44947b7 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/normalization.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
@@ -16,8 +16,12 @@
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
-def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
- torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
+def torch_nn_normalize(
+ self: Union[
+ torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
+ ],
+ input: torch.Tensor,
+) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None
if self.training:
@@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch
try:
import apex
+
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
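
As a concrete instance of the affine bookkeeping above: group-normalizing an (8, 32, 16, 16) input with learnable affine parameters is counted as numel * 5 FLOPs and 0 MACs:

    import torch

    input = torch.randn(8, 32, 16, 16)
    has_affine = True
    flops = input.numel() * (5 if has_affine else 4)
    print(flops)  # 8*32*16*16*5 = 327680
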
diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py
index e429ac3eea28..b3b630b2dee9 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/pooling.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py
index 6e733d6da915..8a4c828dbd27 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/rnn.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py
@@ -1,12 +1,15 @@
-from functools import reduce
import operator
+from functools import reduce
+from typing import Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_module
-from typing import Optional, Tuple, Union
-def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
- w_hh: torch.Tensor) -> Tuple[int, int]:
+def _rnn_flops(
+ flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor
+) -> Tuple[int, int]:
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
# matrix matrix mult ih state and internal state
@@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
flops = 0
macs = 0
for i in range(self.num_layers):
- w_ih = self.__getattr__('weight_ih_l' + str(i))
- w_hh = self.__getattr__('weight_hh_l' + str(i))
+ w_ih = self.__getattr__("weight_ih_l" + str(i))
+ w_hh = self.__getattr__("weight_hh_l" + str(i))
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
- b_ih = self.__getattr__('bias_ih_l' + str(i))
- b_hh = self.__getattr__('bias_hh_l' + str(i))
+ b_ih = self.__getattr__("bias_ih_l" + str(i))
+ b_hh = self.__getattr__("bias_hh_l" + str(i))
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= reduce(operator.mul, input.shape[:2])
macs *= reduce(operator.mul, input.shape[:2])
@@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0
macs = 0
- w_ih = self.__getattr__('weight_ih_l')
- w_hh = self.__getattr__('weight_hh_l')
+ w_ih = self.__getattr__("weight_ih_l")
+ w_hh = self.__getattr__("weight_hh_l")
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
- b_ih = self.__getattr__('bias_ih_l')
- b_hh = self.__getattr__('bias_hh_l')
+ b_ih = self.__getattr__("bias_ih_l")
+ b_hh = self.__getattr__("bias_hh_l")
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= input.shape[0]
macs *= input.shape[0]
diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
index d3aed874eb10..06be25246a71 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
@@ -1,7 +1,8 @@
-import operator
+from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
-from typing import Optional, Tuple, Union
@meta_profiler_module.register(torch.nn.Flatten)
diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py
index 7d73bce321e4..d47129cd2978 100644
--- a/colossalai/fx/profiler/experimental/registry.py
+++ b/colossalai/fx/profiler/experimental/registry.py
@@ -1,11 +1,9 @@
class ProfilerRegistry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
self.store[source] = func
return func
@@ -21,5 +19,5 @@ def has(self, source):
return source in self.store
-meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
-meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
+meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile")
+meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile")
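
A usage sketch of the registry (the key and profiler here are hypothetical; `get` is assumed to return the stored callable, as its uses in profiler.py suggest):

    from colossalai.fx.profiler.experimental.registry import ProfilerRegistry

    registry = ProfilerRegistry(name="demo_registry")

    @registry.register("my_op")  # a function object or a string key
    def profile_my_op(*args):
        return 0, 0  # (flops, macs), both zero for this dummy op

    assert registry.has("my_op")
    flops, macs = registry.get("my_op")()
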
diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py
index 1e53ed0bf8ec..90e8c3b7cfe4 100644
--- a/colossalai/fx/profiler/experimental/shard_utils.py
+++ b/colossalai/fx/profiler/experimental/shard_utils.py
@@ -1,8 +1,6 @@
# for PyTorch 1.11 compatibility
-from typing import Dict, List, Tuple, Union
-import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
from ..._compatibility import compatibility
@@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool:
Returns:
save_fwd_in (bool): the result of `save_fwd_in`
"""
- return n.meta['save_fwd_in']
+ return n.meta["save_fwd_in"]
@compatibility(is_backward_compatible=True)
@@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int:
Returns:
fwd_out (int): the result of `fwd_out`
"""
- return n.meta['fwd_mem_out']
+ return n.meta["fwd_mem_out"]
diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py
index 6ccbcb01cdc1..e8eb5f25cb6c 100644
--- a/colossalai/fx/profiler/memory_utils.py
+++ b/colossalai/fx/profiler/memory_utils.py
@@ -1,11 +1,11 @@
from typing import Dict, List, Tuple, Union
import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
-__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+__all__ = ["activation_size", "parameter_size", "is_inplace"]
@compatibility(is_backward_compatible=True)
@@ -63,6 +63,7 @@ def is_inplace(n: Node):
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
+
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index ba090a2ec51b..8fae0f2ecb45 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -173,8 +173,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
- has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
- 'shape') else inputs[affine_arg_index]
+ has_affine = (
+ inputs[affine_arg_index].shape is not None
+ if hasattr(inputs[affine_arg_index], "shape")
+ else inputs[affine_arg_index]
+ )
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
- return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
+ return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -218,15 +221,16 @@ def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
def zero_flop_jit(*args):
"""
- Count flops for zero flop layers.
+    Count FLOPs for zero-flop layers.
"""
return 0
-if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
- torch.__version__) < version.parse('2.0.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse(
+ "2.0.0"
+):
flop_mapping = {
- # gemm, gemv and dot
+ # gemm, gemv and dot
aten.mm.default: matmul_flop_jit,
aten.mv.default: matmul_flop_jit,
aten.dot.default: matmul_flop_jit,
@@ -234,13 +238,11 @@ def zero_flop_jit(*args):
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
aten.baddbmm.default: baddbmm_flop_jit,
-
- # convolution
+ # convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
-
- # normalization
+ # normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
@@ -249,8 +251,7 @@ def zero_flop_jit(*args):
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
aten.native_group_norm.default: norm_flop_counter(2, 0),
aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
-
- # pooling
+ # pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
@@ -275,7 +276,7 @@ def zero_flop_jit(*args):
}
elementwise_flop_aten = [
- # basic op
+ # basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -296,8 +297,7 @@ def zero_flop_jit(*args):
aten.exp.default,
aten.sin.default,
aten.cos.default,
-
- # activation op
+ # activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -320,8 +320,7 @@ def zero_flop_jit(*args):
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
-
- # dropout
+ # dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
]
@@ -362,7 +361,7 @@ def zero_flop_jit(*args):
aten.zero_.default,
aten.zeros_like.default,
aten.fill_.Scalar,
- aten.stack.default
+ aten.stack.default,
] # yapf: disable
for op in zero_flop_aten:
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index c87cd4321d31..97e70db6290e 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -15,7 +15,7 @@
from .opcount import flop_mapping
from .tensor import MetaTensor
-__all__ = ['profile_function', 'profile_module', 'profile_method']
+__all__ = ["profile_function", "profile_module", "profile_method"]
# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
@@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
# backward is executed.
# Hopefully, this attempt will provide a better estimation of memory.
class FlopTensor(MetaTensor):
-
_node: Node = None
def __repr__(self):
@@ -186,24 +185,24 @@ def __repr__(self):
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
- node = subgraph.create_node('call_function', func, args_node, kwargs_node)
+ node = subgraph.create_node("call_function", func, args_node, kwargs_node)
out = super().__torch_dispatch__(func, types, args, kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
- node.meta['phase'] = phase
+ node.meta["phase"] = phase
# super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
# i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
# `Phase.FORWARD`
if phase == Phase.FORWARD:
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
- node.meta['phase'] = Phase.PLACEHOLDER
+ node.meta["phase"] = Phase.PLACEHOLDER
# TODO(yby): specify `saved_tensors` for backward memory estimation
- node.meta['saved_tensor'] = []
+ node.meta["saved_tensor"] = []
if phase == Phase.BACKWARD:
- node.meta['saved_tensor'] = normalize_tuple(out)
+ node.meta["saved_tensor"] = normalize_tuple(out)
def wrap(x):
if isinstance(x, MetaTensor):
@@ -219,11 +218,14 @@ def wrap(x):
x = FlopTensor(x)
if is_autogradable(x):
x.requires_grad_(True)
- x._node = subgraph.create_node('placeholder',
- 'placeholder', (subgraph._root,),
- name=subgraph._graph_namespace.create_name('input', x._tensor))
- x._node.meta['phase'] = Phase.PLACEHOLDER
- x._node.meta['saved_tensor'] = []
+ x._node = subgraph.create_node(
+ "placeholder",
+ "placeholder",
+ (subgraph._root,),
+ name=subgraph._graph_namespace.create_name("input", x._tensor),
+ )
+ x._node.meta["phase"] = Phase.PLACEHOLDER
+ x._node.meta["saved_tensor"] = []
return x
# Basically, we need to detach the args and kwargs from the outer graph.
@@ -235,7 +237,7 @@ def pack(x):
if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
tensor = x._tensor.detach()
tensor.data_ptr = x._tensor.data_ptr
- x._node.meta['saved_tensor'] += [tensor]
+ x._node.meta["saved_tensor"] += [tensor]
if not do_not_cache:
cache.add(x._tensor.data_ptr())
return x
@@ -284,7 +286,7 @@ def unwrap(x):
@compatibility(is_backward_compatible=True)
-def profile_function(target: 'Target', device: str = 'meta') -> Callable:
+def profile_function(target: "Target", device: str = "meta") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
# find the grad for parameter in args and kwargs
param_size = 0
@@ -316,18 +317,18 @@ def get_param_size(x):
# still run the profiling but discard some results regarding `target`
global do_not_cache
- inplace = kwargs.get('inplace', False)
+ inplace = kwargs.get("inplace", False)
if target in OUTPUT_SAVED_OPS:
do_not_cache = True
if inplace:
do_not_cache = True
- kwargs['inplace'] = False
- if device == 'meta':
+ kwargs["inplace"] = False
+ if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
- kwargs['inplace'] = True
+ kwargs["inplace"] = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False
@@ -341,7 +342,7 @@ def get_param_size(x):
@compatibility(is_backward_compatible=True)
-def profile_method(target: 'Target', device: str = 'meta') -> Callable:
+def profile_method(target: "Target", device: str = "meta") -> Callable:
"""
    Wrap a `call_method` node to
    record the memory cost and FLOPs of the execution.
@@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
- assert isinstance(target, str), f'{target} instance is not str.'
- if device == 'meta':
+ assert isinstance(target, str), f"{target} instance is not str."
+ if device == "meta":
out, meta = _profile_meta(target, *args, **kwargs)
else:
out, meta = _profile_concrete(target, *args, **kwargs)
@@ -360,7 +361,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@compatibility(is_backward_compatible=True)
-def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
+def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
@@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
# calculate parameter size
param_size = parameter_size(module)
@@ -384,13 +384,13 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# still run the profiling but discard some results regarding `module`.
global do_not_cache
- inplace = getattr(module, 'inplace', False)
+ inplace = getattr(module, "inplace", False)
if type(module) in OUTPUT_SAVED_MOD:
do_not_cache = True
if inplace:
do_not_cache = True
module.inplace = False
- if device == 'meta':
+ if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py
index 34feefb4336a..75b7c814f05f 100644
--- a/colossalai/fx/profiler/shard_utils.py
+++ b/colossalai/fx/profiler/shard_utils.py
@@ -59,9 +59,9 @@ def forward(self, input_2):
Returns:
bool: Whether the node is a ReLU-like node
"""
- if n.op == 'call_function':
+ if n.op == "call_function":
return n.target in OUTPUT_SAVED_OPS
- elif n.op == 'call_module':
+ elif n.op == "call_module":
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
return False
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 2ee5e5c47750..7c14b48bdaa1 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -1,13 +1,13 @@
import uuid
import torch
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
-__all__ = ['MetaTensor']
+__all__ = ["MetaTensor"]
def set_data_ptr(x):
@@ -43,12 +43,13 @@ def __new__(cls, elem, fake_device=None):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
- device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
- requires_grad=elem.requires_grad) # deceive the frontend for aten selections
+ device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+ requires_grad=elem.requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
set_data_ptr(r._tensor)
return r
@@ -69,15 +70,15 @@ def unwrap(x):
x = x._tensor
elif isinstance(x, torch.Tensor):
fake_device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
- if 'device' in kwargs:
- fake_device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ fake_device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
@@ -93,7 +94,7 @@ def wrap(x):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
@@ -120,18 +121,18 @@ def replace(x):
nonlocal fake_device
if isinstance(x, str) or isinstance(x, _device):
fake_device = x
- return 'meta'
+ return "meta"
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, fake_device=fake_device)
def cpu(self, *args, **kwargs):
- if self.device.type == 'cpu':
+ if self.device.type == "cpu":
return self.to(*args, **kwargs)
- return self.to(*args, device='cpu', **kwargs)
+ return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
- return self.to(device='cuda:0', non_blocking=non_blocking)
+ return self.to(device="cuda:0", non_blocking=non_blocking)
diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index 7317072c6298..887832223fd6 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -1,12 +1,11 @@
-import operator
-from typing import Any, List, Union
+from typing import Any
import torch
-from torch.fx.proxy import Attribute, Proxy
+from torch.fx.proxy import Proxy
from colossalai.fx.tracer.meta_patch import meta_patched_function
-__all__ = ['ColoProxy']
+__all__ = ["ColoProxy"]
class ColoProxy(Proxy):
@@ -39,11 +38,12 @@ def has_meta_data(self):
return self._meta_data is not None
def _assert_meta_data_is_tensor(self):
- assert torch.is_tensor(
- self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
+ assert (
+ torch.is_tensor(self._meta_data) and self._meta_data.is_meta
+ ), f"Meta data is not a meta tensor for {self.node.name}"
def _assert_has_meta_data(self):
- assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'
+ assert self._meta_data is not None, f"Meta data is not set for {self.node.name}"
def __len__(self):
self._assert_has_meta_data()
@@ -62,7 +62,6 @@ def __bool__(self):
return self.meta_data
def __getattr__(self, k):
-
return ColoAttribute(self, k)
def __contains__(self, key):
@@ -92,7 +91,6 @@ def _convert(val):
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py
index 1c5abb81d271..63a7bab654d5 100644
--- a/colossalai/fx/tracer/_meta_trace.py
+++ b/colossalai/fx/tracer/_meta_trace.py
@@ -39,7 +39,7 @@ class MetaProxy(torch.Tensor):
_tensor: torch.Tensor
_node: Node
- __slots__ = ['_tensor', '_node']
+ __slots__ = ["_tensor", "_node"]
@staticmethod
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
@@ -51,22 +51,22 @@ def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
dtype=tensor.dtype,
layout=tensor.layout,
device=fake_device if fake_device is not None else tensor.device,
- requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
+ requires_grad=tensor.requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
if name is None:
- name = 'input'
- r._node = graph.create_node('placeholder',
- 'placeholder', (graph._root,),
- name=namespace.create_name(name, tensor))
+ name = "input"
+ r._node = graph.create_node(
+ "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
+ )
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
def unwrap(x):
nonlocal fake_device
if isinstance(x, MetaProxy):
@@ -75,21 +75,21 @@ def unwrap(x):
# assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor):
fake_device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
def get_node(x):
- if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
- x = MetaProxy(x, placeholder=True, name='weight')
- return x if not hasattr(x, '_node') else x._node
+ if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
+ x = MetaProxy(x, placeholder=True, name="weight")
+ return x if not hasattr(x, "_node") else x._node
args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
- node = graph.create_node('call_function', func, args_node, kwargs_node)
+ node = graph.create_node("call_function", func, args_node, kwargs_node)
- if 'device' in kwargs:
- fake_device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ fake_device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
@@ -103,9 +103,12 @@ def wrap(x):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
- x = x.to(torch.device('meta'))
- return MetaProxy(
- x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+ x = x.to(torch.device("meta"))
+ return (
+ MetaProxy(x, fake_device=fake_device)
+ if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
+ else x
+ )
def set_node(x):
x._node = node
@@ -125,9 +128,12 @@ def wrap(x):
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
- grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
- tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
- torch.autograd.backward(tensor,
- MetaProxy(grad, fake_device=tensor.device, placeholder=True),
- retain_graph=True)
+ grad = (
+ torch.empty_like(tensor._tensor, device=torch.device("meta"))
+ if isinstance(tensor, MetaProxy)
+ else torch.empty_like(tensor, device=torch.device("meta"))
+ )
+ torch.autograd.backward(
+ tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
+ )
return graph
diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
index e160497a7444..9cf1961d45ff 100644
--- a/colossalai/fx/tracer/_tracer_utils.py
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -2,10 +2,10 @@
import torch
-from ..proxy import ColoAttribute, ColoProxy
-from .meta_patch import meta_patched_function, meta_patched_module
+from ..proxy import ColoProxy
+from .meta_patch import meta_patched_function
-__all__ = ['is_element_in_list', 'extract_meta']
+__all__ = ["is_element_in_list", "extract_meta"]
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
@@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
def extract_meta(*args, **kwargs):
-
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
index 859a19bf6241..84c09109877e 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
@@ -1,7 +1,4 @@
-import operator
-
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@@ -10,13 +7,12 @@
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
kwargs = {}
- if 'beta' in self.kwargs:
- kwargs['beta'] = self.kwargs['beta']
- if 'alpha' in self.kwargs:
- kwargs['alpha'] = self.kwargs['alpha']
+ if "beta" in self.kwargs:
+ kwargs["beta"] = self.kwargs["beta"]
+ if "alpha" in self.kwargs:
+ kwargs["alpha"] = self.kwargs["alpha"]
return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
@@ -25,7 +21,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy):
perform the main computation, such as convolution, with the bias option disabled.
"""
assert self.substitute_func == torch.bmm
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
@@ -35,10 +31,10 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy):
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
- '''
+ """
This method sums the input_proxy along the dimensions given by sum_dims.
- '''
- node_kind = 'call_function'
+ """
+ node_kind = "call_function"
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
@@ -55,15 +51,15 @@ def generate(self):
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
- if 'beta' in kwargs:
- beta = kwargs['beta']
+ if "beta" in kwargs:
+ beta = kwargs["beta"]
# multiply by beta if it is provided (temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
- if 'alpha' in kwargs:
- alpha = kwargs['alpha']
+ if "alpha" in kwargs:
+ alpha = kwargs["alpha"]
# multiply by alpha if it is provided (temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
index fe7d8d07aac9..d087b2913005 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
@@ -1,7 +1,4 @@
-import operator
-
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@@ -10,17 +7,16 @@
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
kwargs = {}
- if 'beta' in self.kwargs:
- kwargs['beta'] = self.kwargs['beta']
- if 'alpha' in self.kwargs:
- kwargs['alpha'] = self.kwargs['alpha']
+ if "beta" in self.kwargs:
+ kwargs["beta"] = self.kwargs["beta"]
+ if "alpha" in self.kwargs:
+ kwargs["alpha"] = self.kwargs["alpha"]
return kwargs
def transpose_other_operand_for_linear(self, other_proxy):
- '''
+ """
This method transposes the other operand for the linear function.
For example:
input = torch.rand(3, 4)
@@ -30,8 +26,8 @@ def transpose_other_operand_for_linear(self, other_proxy):
# To keep the computation graph consistent with the original one, we need to transpose m2
# before we call the linear function.
new_output = torch.linear(m1, m2.transpose(0, 1)) + input
- '''
- node_kind = 'call_function'
+ """
+ node_kind = "call_function"
node_target = torch.transpose
node_args = (other_proxy, 0, 1)
node_kwargs = {}
@@ -43,14 +39,14 @@ def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
kwargs = self.extract_kwargs_from_origin_func()
- if 'beta' in kwargs:
- beta = kwargs['beta']
+ if "beta" in kwargs:
+ beta = kwargs["beta"]
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
- if 'alpha' in kwargs:
- alpha = kwargs['alpha']
+ if "alpha" in kwargs:
+ alpha = kwargs["alpha"]
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
index 8a3786332c08..42178b7b786e 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
@@ -29,7 +29,6 @@ def extract_kwargs_from_origin_func(self):
to insert two more operator.mul nodes so that the computation graph computes the
final result.
"""
- pass
@abstractmethod
def generate(self):
@@ -50,7 +49,6 @@ def generate(self):
%mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
"""
- pass
def create_mul_node(self, input_proxy, coefficent):
"""
@@ -59,7 +57,7 @@ def create_mul_node(self, input_proxy, coefficent):
Therefore, we need to use this method to insert two more operator.mul nodes so that
the computation graph computes the final result.
"""
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = operator.mul
node_args = (
input_proxy,
@@ -82,7 +80,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy):
perform the main computation, such as convolution, with the bias option disabled.
"""
assert self.substitute_func == torch.nn.functional.linear
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
@@ -96,7 +94,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
This method creates the bias_addition_proxy; the node it creates will
compute the sum of the non_bias_func result and the bias, with a reshape operation if needed.
"""
- bias_add_node_kind = 'call_function'
+ bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
index e11ec0a364f1..ed060a350739 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
@@ -1,6 +1,3 @@
-import operator
-
-import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
@@ -9,17 +6,16 @@
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
- assert 'bias' in self.kwargs
+ assert "bias" in self.kwargs
kwargs = {}
- if 'bias' in self.kwargs:
- kwargs['bias'] = self.kwargs['bias']
+ if "bias" in self.kwargs:
+ kwargs["bias"] = self.kwargs["bias"]
return kwargs
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
kwargs = self.extract_kwargs_from_origin_func()
- bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
+ bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"])
return bias_addition_proxy
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
index 591485fdb1ca..19c0e21d7c17 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
@@ -27,8 +27,8 @@ def _create_weight_proxy(self):
Note: this function is invoked during module initialization;
you should never call it yourself.
"""
- weight_node_kind = 'get_attr'
- weight_node_target = self.target + '.weight'
+ weight_node_kind = "get_attr"
+ weight_node_target = self.target + ".weight"
weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
return weight_proxy
@@ -39,8 +39,8 @@ def _create_bias_proxy(self):
Note: this function is invoked during module initialization;
you should never call it yourself.
"""
- bias_node_kind = 'get_attr'
- bias_node_target = self.target + '.bias'
+ bias_node_kind = "get_attr"
+ bias_node_target = self.target + ".bias"
bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
return bias_proxy
@@ -54,14 +54,13 @@ def extract_kwargs_from_mod(self):
considered during module initialization. However, we need to pass those attributes as kwargs
to F.conv2d.
"""
- pass
def create_non_bias_func_proxy(self, input_proxy=None):
"""
This method creates the non_bias_func proxy; the node it creates will
perform the main computation, such as convolution, with the bias option disabled.
"""
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
if input_proxy is None:
input_proxy = self.args[0]
@@ -75,7 +74,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
This method creates the bias_addition_proxy; the node it creates will
compute the sum of the non_bias_func result and the bias, with a reshape operation if needed.
"""
- bias_add_node_kind = 'call_function'
+ bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
@@ -100,7 +99,6 @@ def generate(self):
%view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
"""
- pass
module_to_func_dict = {
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
index 4b6c82a74f57..812a141c1eab 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
@@ -1,6 +1,5 @@
import torch
-import torch.nn.functional as F
-from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
+from torch.nn.modules.utils import _pair, _single, _triple
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@@ -10,17 +9,16 @@
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):
-
def extract_kwargs_from_mod(self):
root = self.tracer.root
conv_module = root.get_submodule(self.target)
- kwarg_attributes = ['groups', 'dilation', 'stride']
+ kwarg_attributes = ["groups", "dilation", "stride"]
non_bias_kwargs = {}
for attr_name in kwarg_attributes:
if hasattr(conv_module, attr_name):
non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
if conv_module.padding_mode != "zeros":
- #TODO: non zeros mode requires some extra processing for input
+ # TODO: non-zero padding modes require extra processing of the input
conv_type = type(conv_module)
if conv_type == torch.nn.Conv1d:
padding_element = _single(0)
@@ -28,9 +26,9 @@ def extract_kwargs_from_mod(self):
padding_element = _pair(0)
elif conv_type == torch.nn.Conv3d:
padding_element = _triple(0)
- non_bias_kwargs['padding'] = padding_element
+ non_bias_kwargs["padding"] = padding_element
else:
- non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
+ non_bias_kwargs["padding"] = getattr(conv_module, "padding")
return non_bias_kwargs
@@ -41,11 +39,12 @@ def create_bias_reshape_proxy(self, dimensions):
"""
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
- bias_reshape_node_kind = 'call_method'
- bias_reshape_node_target = 'view'
+ bias_reshape_node_kind = "call_method"
+ bias_reshape_node_target = "view"
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
- bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
- bias_reshape_node_args, {})
+ bias_reshape_proxy = self.tracer.create_proxy(
+ bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {}
+ )
return bias_reshape_proxy
def generate(self):
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
index f6f7b6ddab40..b397f009846c 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
@@ -1,5 +1,4 @@
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@@ -7,7 +6,6 @@
@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):
-
def extract_kwargs_from_mod(self):
return {}
diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py
index 22a67d1ceccc..e6e511b72fbb 100644
--- a/colossalai/fx/tracer/experimental.py
+++ b/colossalai/fx/tracer/experimental.py
@@ -1,4 +1,3 @@
-import enum
import functools
import inspect
import operator
@@ -10,7 +9,7 @@
from torch.utils._pytree import tree_map
from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
-from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list
+from colossalai.fx.tracer._tracer_utils import is_element_in_list
from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from colossalai.fx.tracer.registry import (
bias_addition_function,
@@ -24,31 +23,45 @@
from colossalai.fx.profiler import MetaTensor
Target = Union[Callable[..., Any], str]
-Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
- List[Any], # actually Argument
- Dict[str, Any], # actually Argument
- slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
- 'Node',]]
-_CScriptMethod = ['add', 'mul', 'sub', 'div']
+Argument = Optional[
+ Union[
+ Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
+ List[Any], # actually Argument
+ Dict[str, Any], # actually Argument
+ slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
+ "Node",
+ ]
+]
+_CScriptMethod = ["add", "mul", "sub", "div"]
_TorchNewMethod = [
- "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor",
- "finfo"
+ "arange",
+ "zeros",
+ "zeros_like",
+ "ones",
+ "ones_like",
+ "full",
+ "full_like",
+ "empty",
+ "empty_like",
+ "eye",
+ "tensor",
+ "finfo",
]
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
def _truncate_suffix(s: str):
import re
- return re.sub(r'_\d+$', '', s)
+
+ return re.sub(r"_\d+$", "", s)
def default_device():
- return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@compatibility(is_backward_compatible=False)
class ColoProxy(Proxy):
-
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._meta_data = data
@@ -100,7 +113,7 @@ def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
- proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+ proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -125,29 +138,28 @@ def ndim(self):
@property
def device(self):
- proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
+ proxy = self.tracer.create_proxy("call_function", getattr, (self, "device"), {})
proxy.meta_data = self.meta_data.device
return proxy
@property
def dtype(self):
- proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
+ proxy = self.tracer.create_proxy("call_function", getattr, (self, "dtype"), {})
proxy.meta_data = self.meta_data.dtype
return proxy
def to(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "to", (self, *args), {**kwargs})
def cpu(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "cpu", (self, *args), {**kwargs})
def cuda(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "cuda", (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -160,11 +172,11 @@ def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
- self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+ self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+ return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
@@ -172,7 +184,6 @@ def __repr__(self):
@compatibility(is_backward_compatible=False)
class ColoTracer(Tracer):
-
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self._disable_module_getattr = False
@@ -184,24 +195,28 @@ def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count = 0
- def proxy(self, node: Node) -> 'ColoProxy':
+ def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
- def create_proxy(self,
- kind: str,
- target: Target,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- name: Optional[str] = None,
- type_expr: Optional[Any] = None,
- proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-
+ def create_proxy(
+ self,
+ kind: str,
+ target: Target,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None,
+ proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+ ):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
- if kind == 'placeholder':
- proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
- _truncate_suffix(target), None)
- elif kind == 'get_attr':
+ if kind == "placeholder":
+ proxy.meta_data = (
+ self.meta_args[target]
+ if target in self.meta_args
+ else self.concrete_args.get(_truncate_suffix(target), None)
+ )
+ elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
@@ -211,20 +226,21 @@ def create_proxy(self,
proxy.meta_data = attr_itr
finally:
self._disable_module_getattr = False
- elif kind == 'call_function':
+ elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
+ elif kind == "call_method":
self._disable_module_getattr = True
try:
- if target == '__call__':
+ if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
+ proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
finally:
self._disable_module_getattr = False
- elif kind == 'call_module':
+ elif kind == "call_module":
mod = self.root.get_submodule(target)
self._disable_module_getattr = True
try:
@@ -238,14 +254,15 @@ def create_node(self, *args, **kwargs) -> Node:
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
- node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+ node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node
- def trace(self,
- root: torch.nn.Module,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None,
- meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
-
+ def trace(
+ self,
+ root: torch.nn.Module,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+ meta_args: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Graph:
if meta_args is None:
meta_args = {}
@@ -260,20 +277,19 @@ def trace(self,
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get concrete arg names
concrete_arg_names = set(concrete_args.keys())
- non_concrete_arg_names = sig_names - concrete_arg_names
def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
- f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+ f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+ )
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
@@ -292,7 +308,6 @@ def trace_activation_checkpoint(self, enabled: bool):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
-
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within activation checkpoint part
@@ -305,7 +320,8 @@ def forward(ctx, run_function, preserve_rng_state, *args):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
- "We do not implement the backward pass as we only trace the forward pass.")
+ "We do not implement the backward pass as we only trace the forward pass."
+ )
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -356,10 +372,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
- if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
- kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
- lambda node: ColoProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ColoProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -370,8 +389,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
@@ -389,42 +409,41 @@ def symbolic_trace(
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
- graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
- concrete_args=concrete_args,
- meta_args=tree_map(wrap_fn, meta_args))
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
+ root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+ )
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
- graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
- concrete_args=concrete_args,
- meta_args=meta_args)
+
+ graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
+ root, concrete_args=concrete_args, meta_args=meta_args
+ )
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
@compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object):
-
def __init__(self, tracer: Tracer):
self.overrides = {}
self.tracer = tracer
def __enter__(self):
-
def wrap_tensor_method(target):
-
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
- isinstance(p, ColoProxy) for p in kwargs.values())
+ isinstance(p, ColoProxy) for p in kwargs.values()
+ )
if is_proxy:
# if any arg is a proxy, we need to record this function call on the proxy
# e.g. torch.ones(size) where size is an input proxy
self.tracer._disable_module_getattr = True
try:
- proxy = self.tracer.create_proxy('call_function', target, args, kwargs)
+ proxy = self.tracer.create_proxy("call_function", target, args, kwargs)
finally:
self.tracer._disable_module_getattr = False
return proxy
@@ -446,11 +465,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
setattr(torch, name, orig)
-def meta_prop_pass(gm: ColoGraphModule,
- root: torch.nn.Module,
- meta_args: Optional[Dict[str, Any]] = None,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None):
-
+def meta_prop_pass(
+ gm: ColoGraphModule,
+ root: torch.nn.Module,
+ meta_args: Optional[Dict[str, Any]] = None,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+):
if meta_args is None:
meta_args = {}
@@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule,
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
for node in gm.graph.nodes:
- node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
- node.kwargs)
+ node._meta_data = _meta_data_computing(
+ meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs
+ )
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
- if kind == 'placeholder':
+ if kind == "placeholder":
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
- elif kind == 'get_attr':
+ elif kind == "get_attr":
attr_itr = root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
meta_out = attr_itr
- elif kind == 'call_function':
+ elif kind == "call_function":
meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
- if target == '__call__':
+ elif kind == "call_method":
+ if target == "__call__":
meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_module':
+ meta_out = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
+ elif kind == "call_module":
mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
else:
@@ -603,26 +623,30 @@ def wrap_fn(n):
if kind == "call_function":
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
- if 'bias' in kwargs and kwargs['bias'] is not None:
+ if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
else:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif bias_addition_function.has(target.__name__):
# use the function name for built-in ops like @ (matmul)
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target.__name__)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
- handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_method.get(method)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif kind == "call_module":
# if not hasattr(self, "orig_forward"):
@@ -631,8 +655,9 @@ def wrap_fn(n):
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
- handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_module.get(mod_type)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
if handle is not None:
handle.generate()
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
index 12c42514895e..75d7b18a067c 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
@@ -5,4 +5,4 @@
@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index 042b92c5847a..3475f22e3b19 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -4,7 +4,7 @@
@meta_patched_function.register(torch.matmul)
-@meta_patched_function.register('matmul') # for built-in op @
+@meta_patched_function.register("matmul") # for built-in op @
def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx
d1 = input.dim()
@@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None):
@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
- assert out is None, 'out is not supported yet'
- return torch.empty(input.shape, device='meta')
+ assert out is None, "out is not supported yet"
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.bmm)
@@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
- assert out is None, 'saving to out is not supported yet'
- var = torch.empty(1).squeeze(0).to('meta')
- mean = torch.empty(1).squeeze(0).to('meta')
+ assert out is None, "saving to out is not supported yet"
+ var = torch.empty(1).squeeze(0).to("meta")
+ mean = torch.empty(1).squeeze(0).to("meta")
return var, mean
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
index 8500e5c82508..26daf32a2afc 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
@@ -8,7 +8,6 @@
def _ntuple(n, name="parse"):
-
def parse(x):
if isinstance(x, collections.abc.Iterable):
return tuple(x)
@@ -24,21 +23,21 @@ def parse(x):
def _extract_kwargs(kwargs):
- if 'stride' in kwargs:
- stride = kwargs['stride']
+ if "stride" in kwargs:
+ stride = kwargs["stride"]
else:
stride = 1
# TODO: process str type padding
- if 'padding' in kwargs:
- padding = kwargs['padding']
+ if "padding" in kwargs:
+ padding = kwargs["padding"]
else:
padding = 0
- if 'dilation' in kwargs:
- dilation = kwargs['dilation']
+ if "dilation" in kwargs:
+ dilation = kwargs["dilation"]
else:
dilation = 1
- if 'output_padding' in kwargs:
- output_padding = kwargs['output_padding']
+ if "output_padding" in kwargs:
+ output_padding = kwargs["output_padding"]
else:
output_padding = 0
@@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs):
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv2d)
@@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv3d)
@@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose1d)
@@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
l_in = input.shape[-1]
c_out = weight.shape[1]
- l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose2d)
@@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
h_in, w_in = input.shape[-2:]
c_out = weight.shape[1]
- h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
- output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose3d)
@@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
d_in, h_in, w_in = input.shape[-3:]
c_out = weight.shape[1]
- d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
- output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) +
- output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
index 6d8d864ea29a..27a79f18590a 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
@@ -4,11 +4,7 @@
@meta_patched_function.register(torch.nn.functional.embedding)
-def torch_nn_functional_embedding(input,
- weight,
- padding_idx=None,
- max_norm=None,
- norm_type=2.0,
- scale_grad_by_freq=False,
- sparse=False):
+def torch_nn_functional_embedding(
+ input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
+):
return torch.empty(*input.shape, weight.shape[-1], device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
index e9e7eda6159c..8a6214990830 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
@@ -5,16 +5,11 @@
@meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.nn.functional.batch_norm)
-def torch_nn_func_batchnorm(input,
- running_mean,
- running_var,
- weight=None,
- bias=None,
- training=False,
- momentum=0.1,
- eps=1e-05):
- return torch.empty(input.shape, device='meta')
+def torch_nn_func_batchnorm(
+ input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05
+):
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
index 4c171cb10991..7642934a409b 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
@@ -19,9 +19,9 @@ def to_concrete(t):
return t
def _slice_convert(slice_obj):
- attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step}
+ attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step}
new_attrs = _slice_attr_convert(attrs)
- attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step'])
+ attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"])
return slice(*attr_dict_to_tuple)
def _slice_attr_convert(attrs):
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
index b14ff10ce137..c61e1c4dc9e1 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
@@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
shapes = [t.shape for t in tensors]
shape = list(shapes[0])
concatenated_dim = sum(shape[dim] for shape in shapes)
- final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
+ final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
return torch.empty(final_shape, device="meta")
@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
- assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
- "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
+ assert isinstance(repeats, int) or isinstance(
+ repeats, torch.Tensor
+ ), "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
shape = list(input.shape) if dim is not None else [input.numel()]
dim = dim if dim is not None else 0
@@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None)
@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
- assert out is None, 'assigning result to out is not supported yet'
- return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
+ assert out is None, "assigning result to out is not supported yet"
+ return torch.empty(size, device="meta", dtype=dtype, layout=layout, requires_grad=requires_grad)
@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
- assert out is None, 'assigning value to out is not supported yet'
+ assert out is None, "assigning value to out is not supported yet"
if dim is not None:
if isinstance(dim, int):
shape = list(input.shape)
shape.pop(dim)
if keepdim:
shape.insert(dim, 1)
- return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape,
- device='meta',
- dtype=input.dtype)
+ return torch.empty(shape, device="meta", dtype=input.dtype), torch.empty(
+ shape, device="meta", dtype=input.dtype
+ )
elif isinstance(dim, torch.Tensor):
# when dim is a 0D or 1D tensor, it will maintain the same shape
num_dims = dim.dim()
if num_dims in [0, 1]:
- return torch.empty_like(input, device='meta')
+ return torch.empty_like(input, device="meta")
else:
raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions")
else:
- return torch.empty([], device='meta', dtype=input.dtype)
+ return torch.empty([], device="meta", dtype=input.dtype)
@meta_patched_function.register(torch.Tensor.cpu)
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
index e28e52585fff..3f40ec2a67ee 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
@@ -4,4 +4,4 @@
from .linear import *
from .normalization import *
from .pooling import *
-from .rnn import *
\ No newline at end of file
+from .rnn import *
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
index d03da6588c1c..aa2ede187d37 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
@@ -10,4 +10,4 @@
@meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
index cf9f3487aac9..35173a68a0be 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
@@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
l_in = input.shape[-1]
c_out = self.out_channels
- l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+ l_out = math.floor(
+ (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv2d)
@@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d
h_in, w_in = input.shape[-2:]
c_out = self.out_channels
- h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+ h_out = math.floor(
+ (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv3d)
@@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d
d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
- w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
- (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
+ d_out = math.floor(
+ (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ h_out = math.floor(
+ (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose1d)
@@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
l_in = input.shape[-1]
c_out = self.out_channels
- l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose2d)
@@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
h_in, w_in = input.shape[-2:]
c_out = self.out_channels
- h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose3d)
@@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
- (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[2]
+ - 2 * self.padding[2]
+ + self.dilation[2] * (self.kernel_size[2] - 1)
+ + self.output_padding[2]
+ + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
index 999e33b17c1c..f28647e9caa5 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
@@ -6,4 +6,4 @@
@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
result_shape = input.shape + (self.embedding_dim,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py
index 56f13bf97532..97e6b0e96e83 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py
@@ -6,5 +6,7 @@
@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
last_dim = input.shape[-1]
- assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
+ assert (
+ last_dim == self.in_features
+ ), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch"
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
index c21ff64cf3de..198e72e342b1 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
@@ -23,6 +23,7 @@ def torch_nn_normalize(self, input):
try:
import apex
+
meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
index 7ce23fbf7ac9..450586d02f8f 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
@@ -8,7 +8,7 @@
@meta_patched_module.register(torch.nn.AvgPool1d)
def torch_nn_avgpool1d(self, input):
num_dim = input.dim()
- assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1]
@@ -25,13 +25,13 @@ def _convert_int_to_list(item):
l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool2d)
def torch_nn_avgpool2d(self, input):
num_dim = input.dim()
- assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:]
@@ -52,13 +52,13 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool3d)
def torch_nn_avgpool3d(self, input):
num_dim = input.dim()
- assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:]
@@ -81,13 +81,13 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool1d)
def torch_nn_maxpool1d(self, input):
num_dim = input.dim()
- assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1]
@@ -105,13 +105,13 @@ def _convert_int_to_list(item):
l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool2d)
def torch_nn_maxpool2d(self, input):
num_dim = input.dim()
- assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:]
@@ -133,13 +133,13 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool3d)
def torch_nn_maxpool3d(self, input):
num_dim = input.dim()
- assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:]
@@ -163,7 +163,7 @@ def _convert_int_to_list(item):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-1]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-2]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-3]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
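The pooling patches above all follow the same recipe: compute the output length with the standard pooling formula, then allocate an empty tensor on the `meta` device so tracing never materializes data. A minimal, self-contained sketch of that recipe for `AvgPool1d` (illustrative only; the helper name `avgpool1d_meta` is invented and not part of the patch):

```python
import math

import torch


def avgpool1d_meta(input: torch.Tensor, kernel_size: int, stride: int, padding: int) -> torch.Tensor:
    # Same shape arithmetic as the patched torch.nn.AvgPool1d above.
    num_dim = input.dim()
    assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
    l_in = input.shape[-1]
    l_out = math.floor((l_in + 2 * padding - kernel_size) / stride + 1)
    return torch.empty(tuple(input.shape[:-1]) + (l_out,), device="meta")


x = torch.empty(4, 3, 32, device="meta")  # no real memory is allocated
print(avgpool1d_meta(x, kernel_size=2, stride=2, padding=0).shape)  # torch.Size([4, 3, 16])
```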
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
index ee15ca34162e..bfb7ed171186 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
import torch
from ...registry import meta_patched_module
@@ -8,9 +6,11 @@
@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx):
- assert input.shape[
- -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
- assert hx.shape[
- -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'
+ assert (
+ input.shape[-1] == self.input_size
+ ), f"Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch"
+ assert (
+ hx.shape[-1] == self.hidden_size
+ ), f"Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch"
d = 2 if self.bidirectional else 1
return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
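The reshaped asserts do not change behavior: the patch still checks only the trailing dimensions and returns a meta tensor whose last dimension is `hidden_size * d`. A quick sanity check of that shape rule, with made-up sizes:

```python
import torch

hidden_size, bidirectional = 8, True
d = 2 if bidirectional else 1

inp = torch.empty(5, 3, 4, device="meta")  # (seq_len, batch, input_size)
out = torch.empty(inp.shape[:-1] + (hidden_size * d,), device="meta")
print(out.shape)  # torch.Size([5, 3, 16])
```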
diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py
index 12fc6de73d44..80b3868bb4fe 100644
--- a/colossalai/fx/tracer/registry.py
+++ b/colossalai/fx/tracer/registry.py
@@ -1,11 +1,9 @@
class PatchRegistry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
self.store[source] = func
return func
@@ -21,8 +19,8 @@ def has(self, source):
return source in self.store
-meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
-meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
-bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
-bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
-bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
+meta_patched_function = PatchRegistry(name="patched_functions_for_meta_execution")
+meta_patched_module = PatchRegistry(name="patched_modules_for_meta_execution")
+bias_addition_function = PatchRegistry(name="patched_function_for_bias_addition")
+bias_addition_module = PatchRegistry(name="patched_module_for_bias_addition")
+bias_addition_method = PatchRegistry(name="patched_method_for_bias_addition")
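For readers unfamiliar with the registry pattern: `register` is a decorator factory that stores a patch function under a source key, and `has`/`get` look it up at trace time. A toy usage sketch (the `demo` registry and `patched_relu` are invented for illustration; `get` is not shown in this hunk but is used elsewhere in the tracer):

```python
class PatchRegistry:
    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        return self.store[source]

    def has(self, source):
        return source in self.store


demo = PatchRegistry(name="demo")


@demo.register("relu")
def patched_relu(x):
    return max(x, 0)


assert demo.has("relu")
print(demo.get("relu")(-3))  # 0
```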
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 28965a1b8e74..d9cb587b5d39 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -29,7 +29,7 @@
meta_patched_module,
)
-__all__ = ['ColoTracer']
+__all__ = ["ColoTracer"]
class TracerType(enum.Enum):
@@ -103,7 +103,7 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr
if kind == "call_function":
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
- if 'bias' in kwargs and kwargs['bias'] is not None:
+ if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
else:
@@ -160,22 +160,27 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
- kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
- lambda node: ParameterProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ParameterProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
- maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
- parameter_proxy_cache)
+ maybe_buffer_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_buffers(), parameter_proxy_cache
+ )
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
@@ -190,7 +195,7 @@ def call_module(self, m, forward, args, kwargs):
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
- return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+ return self.create_proxy("call_module", module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
@@ -211,7 +216,6 @@ def _configure_tracer_type(self, tracer_type: TracerType):
raise ValueError(f"Unrecognized tracer type {tracer_type}")
def _meta_data_computing(self, kind, target, args, kwargs):
-
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
meta_out = self.meta_args[target]
return meta_out
@@ -235,8 +239,9 @@ def _meta_data_computing(self, kind, target, args, kwargs):
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node.
convert_to_parameter = False
- if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
- torch.nn.parameter.Parameter):
+ if target in (torch.transpose, torch.reshape) and isinstance(
+ args_metas[0], torch.nn.parameter.Parameter
+ ):
convert_to_parameter = True
# fetch patched function
if meta_patched_function.has(target):
@@ -309,10 +314,12 @@ def _meta_data_computing(self, kind, target, args, kwargs):
return meta_out
- def trace(self,
- root: nn.Module,
- concrete_args: Optional[Dict[str, Tensor]] = None,
- meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
+ def trace(
+ self,
+ root: nn.Module,
+ concrete_args: Optional[Dict[str, Tensor]] = None,
+ meta_args: Optional[Dict[str, Tensor]] = None,
+ ) -> Graph:
"""
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
@@ -341,9 +348,7 @@ def trace(self,
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
@@ -354,7 +359,8 @@ def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
- f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+ f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+ )
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
@@ -363,11 +369,13 @@ def _check_arg_name_valid(names):
def _check_kwargs(kwargs, should_be_meta: bool):
for k, v in kwargs.items():
if not should_be_meta:
- assert not torch.is_tensor(v) or not v.is_meta, \
- f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
+ assert (
+ not torch.is_tensor(v) or not v.is_meta
+ ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer"
else:
- assert v.is_meta == should_be_meta, \
- f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+ assert (
+ v.is_meta == should_be_meta
+ ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer"
_check_kwargs(concrete_args, should_be_meta=False)
_check_kwargs(meta_args, should_be_meta=True)
@@ -442,7 +450,6 @@ def trace_activation_checkpoint(self, enabled: bool):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
-
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within activation checkpoint part
@@ -455,7 +462,8 @@ def forward(ctx, run_function, preserve_rng_state, *args):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
- "We do not implement the backward pass as we only trace the forward pass.")
+ "We do not implement the backward pass as we only trace the forward pass."
+ )
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -470,12 +478,11 @@ def create_node(self, *args, **kwargs) -> Node:
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
- node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+ node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node
def wrap_tensor_constructor_method(target):
-
def look_for_proxy(*args, **kwargs):
# find in pos vars
for arg in args:
@@ -518,12 +525,10 @@ def wrapper(*args, **kwargs):
for method in magic_methods:
def _scope(method):
-
def impl(*args, **kwargs):
-
tracer = args[0].tracer
target = getattr(operator, method)
- proxy = tracer.create_proxy('call_function', target, args, kwargs)
+ proxy = tracer.create_proxy("call_function", target, args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
proxy = ColoProxy(proxy.node)
@@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name):
def impl(self, rhs):
target = getattr(operator, orig_method_name)
- proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+ proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {})
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
proxy = ColoProxy(proxy.node)
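The reflowed `trace` signature keeps the same three-way contract: `concrete_args` are executed with real values, `meta_args` are traced with meta tensors (and must satisfy `is_meta == True`, as `_check_kwargs` enforces), and defaults fill in the rest. A hypothetical call, assuming the default `ColoTracer()` constructor and a model whose forward takes a single `x` argument:

```python
import torch
import torch.nn as nn

from colossalai.fx.tracer.tracer import ColoTracer


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x)


tracer = ColoTracer()
# Meta args must be meta tensors; concrete args must not be (see _check_kwargs above).
graph = tracer.trace(root=TinyModel(), meta_args={"x": torch.empty(2, 16, device="meta")})
print(graph)
```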
diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py
index e467b4c73e6b..112b920ba158 100644
--- a/colossalai/inference/tensor_parallel/__init__.py
+++ b/colossalai/inference/tensor_parallel/__init__.py
@@ -1,4 +1,4 @@
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager
-__all__ = ['MemoryManager', 'TPInferEngine']
+__all__ = ["MemoryManager", "TPInferEngine"]
diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py
index 2bff9317283e..ac185f1b6529 100644
--- a/colossalai/inference/tensor_parallel/batch_infer_state.py
+++ b/colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -1,6 +1,5 @@
# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later

from dataclasses import dataclass
-from typing import Any
import torch
@@ -31,7 +30,7 @@ class BatchInferState:
decode_mem_index: torch.Tensor = None
decode_layer_id: int = None
- device: torch.device = torch.device('cuda')
+ device: torch.device = torch.device("cuda")
@property
def total_token_num(self):
@@ -43,13 +42,15 @@ def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager
@staticmethod
- def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
- alloc_mem_index: torch.Tensor):
- """ in-place update block loc mapping based on the sequence length of the inputs in current bath"""
+ def init_block_loc(
+ b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
+ ):
+ """in-place update block loc mapping based on the sequence length of the inputs in current bath"""
start_index = 0
seq_len_numpy = seq_len.cpu().numpy()
for i, cur_seq_len in enumerate(seq_len_numpy):
- b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
- cur_seq_len]
+ b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
+ start_index : start_index + cur_seq_len
+ ]
start_index += cur_seq_len
return
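The reformatted slice in `init_block_loc` is easier to read but computes exactly the same right-aligned mapping: because sequences are left-padded, each row of `b_loc` is filled from the tail. A standalone sketch with toy sizes:

```python
import torch

max_len_in_batch = 6
seq_len = torch.tensor([2, 4])
alloc_mem_index = torch.arange(100, 100 + int(seq_len.sum()))  # fake cache slot indexes
b_loc = torch.zeros(2, max_len_in_batch, dtype=torch.long)

start_index = 0
for i, cur_seq_len in enumerate(seq_len.tolist()):
    # tokens of sequence i occupy the last cur_seq_len columns of row i
    b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
        start_index : start_index + cur_seq_len
    ]
    start_index += cur_seq_len

print(b_loc)
# tensor([[  0,   0,   0,   0, 100, 101],
#         [  0,   0, 102, 103, 104, 105]])
```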
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index a5a55702ade0..1335f13d66b8 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union
import torch
import torch.nn as nn
@@ -15,7 +15,7 @@
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
-_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
+_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
class TPInferEngine:
@@ -39,14 +39,16 @@ class TPInferEngine:
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""
- def __init__(self,
- model: nn.Module,
- shard_config: ShardConfig,
- max_batch_size: int,
- max_input_len: int,
- max_output_len: int,
- dtype: torch.dtype = torch.float16,
- device: str = 'cuda') -> None:
+ def __init__(
+ self,
+ model: nn.Module,
+ shard_config: ShardConfig,
+ max_batch_size: int,
+ max_input_len: int,
+ max_output_len: int,
+ dtype: torch.dtype = torch.float16,
+ device: str = "cuda",
+ ) -> None:
self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
@@ -63,7 +65,7 @@ def __init__(self,
self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers
- self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
+ self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
self.shard_config = shard_config
@@ -74,9 +76,10 @@ def __init__(self,
def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
- self.head_num //= self.tp_size # update sharded number of heads
- self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
- self.layer_num)
+ self.head_num //= self.tp_size # update sharded number of heads
+ self.cache_manager = MemoryManager(
+ self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
+ )
def _optimize_model(self, model: nn.Module) -> None:
"""
@@ -90,7 +93,7 @@ def _optimize_model(self, model: nn.Module) -> None:
self._shard_model_by(shardformer, model)
def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
- """ Prepare the engine with a given ShardConfig.
+ """Prepare the engine with a given ShardConfig.
Args:
shard_config (ShardConfig): shard config given to specify settings of the engine.
@@ -118,9 +121,10 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
return shard_config
def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
- """ Shard original model by the given ShardFormer and store the sharded model. """
- assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
- "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
+ """Shard original model by the given ShardFormer and store the sharded model."""
+ assert (
+ self.tp_size == shardformer.shard_config.tensor_parallel_size
+ ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)
@@ -147,7 +151,7 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].cuda()
- if 'max_new_tokens' not in generate_kwargs:
+ if "max_new_tokens" not in generate_kwargs:
generate_kwargs.update(max_new_tokens=self.max_output_len)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
@@ -176,18 +180,18 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
attention_mask = None
if isinstance(inputs, (BatchEncoding, dict)):
- input_ids_list = inputs['input_ids']
- attention_mask = inputs['attention_mask']
+ input_ids_list = inputs["input_ids"]
+ attention_mask = inputs["attention_mask"]
else:
input_ids_list = inputs
- if isinstance(input_ids_list[0], int): # for a single input
+ if isinstance(input_ids_list[0], int): # for a single input
input_ids_list = [input_ids_list]
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
- seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
- seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
@@ -210,10 +214,10 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
- block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
+ block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
- batch_infer_state.seq_len = seq_lengths.to('cuda')
- batch_infer_state.start_loc = seq_start_indexes.to('cuda')
+ batch_infer_state.seq_len = seq_lengths.to("cuda")
+ batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
@@ -248,7 +252,7 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer
- setattr(model, 'infer_state', batch_infer_state)
+ setattr(model, "infer_state", batch_infer_state)
outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
@@ -262,14 +266,15 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch
# as an arg into model.forward.
# It requires rewriting model generate and replacing model forward.
@torch.no_grad()
- def _generate_by_pass_infer_state(self,
- input_tokens,
- max_out_length: int,
- generation_config: Optional[GenerationConfig] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
-
+ def _generate_by_pass_infer_state(
+ self,
+ input_tokens,
+ max_out_length: int,
+ generation_config: Optional[GenerationConfig] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ **model_kwargs,
+ ) -> torch.Tensor:
raise NotImplementedError("generate by passing BatchInferState is not implemented.")
# might want to use in rewritten generate method: use after model.forward
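The `_init_manager` hunk makes the cache-sizing logic explicit: heads are divided evenly across TP ranks before the `MemoryManager` is built. A back-of-the-envelope sketch of the per-rank KV cache footprint, assuming fp16 buffers and `max_total_token_num = max_batch_size * (max_input_len + max_output_len)` (that attribute's definition lies outside this hunk, so the formula is an assumption):

```python
head_num, tp_size = 32, 4
head_dim, layer_num = 128, 32
max_batch_size, max_input_len, max_output_len = 8, 1024, 256

assert head_num % tp_size == 0, f"Cannot shard {head_num} heads with tp size {tp_size}"
head_num //= tp_size  # sharded number of heads per rank

max_total_token_num = max_batch_size * (max_input_len + max_output_len)  # assumed formula
bytes_per_elem = 2  # fp16
kv_factor = 2  # one key buffer and one value buffer per layer
total_bytes = kv_factor * layer_num * max_total_token_num * head_num * head_dim * bytes_per_elem
print(f"{total_bytes / 2**30:.2f} GiB of KV cache per rank")  # 1.25 GiB
```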
diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py
index 274c01841279..e74a3a491a7b 100644
--- a/colossalai/inference/tensor_parallel/kvcache_manager.py
+++ b/colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -19,13 +19,15 @@ class MemoryManager:
device: device used to store the key and value cache
"""
- def __init__(self,
- size: int,
- dtype: torch.dtype,
- head_num: int,
- head_dim: int,
- layer_num: int,
- device: torch.device = torch.device('cuda')):
+ def __init__(
+ self,
+ size: int,
+ dtype: torch.dtype,
+ head_num: int,
+ head_dim: int,
+ layer_num: int,
+ device: torch.device = torch.device("cuda"),
+ ):
self.logger = logging.get_logger(__name__)
self.available_size = size
self.past_key_values_length = 0
@@ -33,13 +35,13 @@ def __init__(self,
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
def _init_mem_states(self, size, device):
- """ Initialize tensors used to manage memory states """
+ """Initialize tensors used to manage memory states"""
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
- """ Initialize key buffer and value buffer on specified device """
+ """Initialize key buffer and value buffer on specified device"""
self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
@@ -49,10 +51,9 @@ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
@torch.no_grad()
def alloc(self, required_size):
- """ allocate space of required_size by providing indexes representing available physical spaces """
+ """allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size:
- self.logger.warning(f"No enough cache: required_size {required_size} "
- f"left_size {self.available_size}")
+ self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
@@ -63,23 +64,25 @@ def alloc(self, required_size):
@torch.no_grad()
def alloc_contiguous(self, required_size):
- """ allocate contiguous space of required_size """
+ """allocate contiguous space of required_size"""
if required_size > self.available_size:
- self.logger.warning(f"No enough cache: required_size {required_size} "
- f"left_size {self.available_size}")
+ self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum)
- loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size +
- 1] + self.mem_state[0:sum_size -
- required_size + 1]
- can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
+ loc_sums = (
+ self.mem_cum_sum[required_size - 1 :]
+ - self.mem_cum_sum[0 : sum_size - required_size + 1]
+ + self.mem_state[0 : sum_size - required_size + 1]
+ )
+ can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0:
- self.logger.info(f"No enough contiguous cache: required_size {required_size} "
- f"left_size {self.available_size}")
+ self.logger.info(
+ f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
+ )
return None
start_loc = can_used_loc[0]
- select_index = self.indexes[start_loc:start_loc + required_size]
+ select_index = self.indexes[start_loc : start_loc + required_size]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
start = start_loc.item()
@@ -88,13 +91,13 @@ def alloc_contiguous(self, required_size):
@torch.no_grad()
def free(self, free_index):
- """ free memory by updating memory states based on given indexes """
+ """free memory by updating memory states based on given indexes"""
self.available_size += free_index.shape[0]
self.mem_state[free_index] = 1
@torch.no_grad()
def free_all(self):
- """ free all memory by updating memory states """
+ """free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.past_key_values_length = 0
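The rewritten `alloc_contiguous` expression is the same sliding-window trick as before: over the free-state mask, a window of length `required_size` is entirely free exactly where the cumulative-sum difference (plus the window's first element) equals `required_size`. A standalone demonstration on a toy mask:

```python
import torch

mem_state = torch.tensor([1, 0, 1, 1, 1, 0, 1], dtype=torch.int32)  # 1 == free slot
indexes = torch.arange(mem_state.numel())
required_size = 3

mem_cum_sum = torch.cumsum(mem_state, dim=0, dtype=torch.int32)
sum_size = len(mem_cum_sum)
loc_sums = (
    mem_cum_sum[required_size - 1 :]
    - mem_cum_sum[0 : sum_size - required_size + 1]
    + mem_state[0 : sum_size - required_size + 1]
)
can_used_loc = indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
print(can_used_loc)  # tensor([2]) -> slots 2..4 form a free contiguous run
```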
diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py
index 7a98b033f37e..27cec5452ece 100644
--- a/colossalai/inference/tensor_parallel/modeling/__init__.py
+++ b/colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -1,4 +1,4 @@
from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards
-__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards']
+__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]
diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
index ba5eadc92be8..27a26caabefa 100644
--- a/colossalai/inference/tensor_parallel/modeling/bloom.py
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -1,6 +1,6 @@
import math
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -31,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16):
"""
def get_slopes_power_of_2(n):
- start = 2**(-(2**-(math.log2(n) - 3)))
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * start**i for i in range(n)]
def get_slopes(n):
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
- closest_power_of_2 = 2**math.floor(math.log2(n))
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
slopes_double = get_slopes(2 * closest_power_of_2)
- slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2]
+ slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
return slopes_combined
slopes = get_slopes(n_head)
@@ -72,7 +72,6 @@ def bloom_model_forward(
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
-
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
@@ -86,8 +85,9 @@ def bloom_model_forward(
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else self.config.output_hidden_states)
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -122,14 +122,15 @@ def bloom_model_forward(
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
use_cache = False
# NOTE determine if BatchInferState is passed in via arg
    # if not, get the attr bound to the model
    # We might want to remove setattr later
if infer_state is None:
- assert hasattr(self, 'infer_state')
+ assert hasattr(self, "infer_state")
infer_state = self.infer_state
# Compute alibi tensor: check build_alibi_tensor documentation
@@ -146,10 +147,11 @@ def bloom_model_forward(
if use_cache and seq_length != 1:
# prefill stage
- infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
- BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
- infer_state.context_mem_index)
+ BatchInferState.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
@@ -182,8 +184,11 @@ def bloom_model_forward(
# alibi = generate_alibi(self.num_heads).contiguous().cuda()
tp_size = dist.get_world_size()
curr_tp_rank = dist.get_rank()
- alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) *
- self.num_heads].cuda()
+ alibi = (
+ generate_alibi(self.num_heads * tp_size)
+ .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
+ .cuda()
+ )
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
@@ -197,7 +202,6 @@ def bloom_model_forward(
if self.gradient_checkpointing and self.training:
# NOTE: currently our KV cache manager does not handle this condition
def create_custom_forward(module):
-
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
@@ -250,32 +254,34 @@ def custom_forward(*inputs):
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
- past_key_values=presents, # should always be (None, None, ..., None)
+ past_key_values=presents, # should always be (None, None, ..., None)
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@staticmethod
- def bloom_for_causal_lm_forward(self: BloomForCausalLM,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- infer_state: Optional[BatchInferState] = None,
- **deprecated_arguments):
+ def bloom_for_causal_lm_forward(
+ self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments,
+ ):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
- logger = logging.get_logger(__name__)
+ logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
@@ -289,17 +295,19 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer,
- input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- infer_state=infer_state)
+ transformer_outputs = BloomInferenceForwards.bloom_model_forward(
+ self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ )
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -314,8 +322,9 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
- shift_labels.view(batch_size * seq_length))
+ loss = loss_fct(
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+ )
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -353,11 +362,13 @@ def bloom_for_causal_lm_prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}
- model_inputs.update({
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "attention_mask": attention_mask,
- })
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
return model_inputs
@staticmethod
@@ -416,7 +427,7 @@ def bloom_block_forward(
else:
outputs = (output,) + outputs[1:]
- return outputs # hidden_states, present, attentions
+ return outputs # hidden_states, present, attentions
@staticmethod
def bloom_attention_forward(
@@ -431,20 +442,19 @@ def bloom_attention_forward(
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
-
- fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape
- k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
- v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
+        k = key_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
+        v = value_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
- if layer_id == 0: # once per model.forward
- infer_state.cache_manager.past_key_values_length += q_length # += 1
+ if layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_length # += 1
if infer_state.is_context_stage:
# context process
@@ -471,9 +481,11 @@ def bloom_attention_forward(
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_v = infer_state.cache_manager.value_buffer[layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_k.copy_(k)
cache_v.copy_(v)
else:
@@ -486,8 +498,17 @@ def bloom_attention_forward(
b_loc = infer_state.block_loc
b_seq_len = infer_state.seq_len
output = torch.empty_like(q)
- token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
- b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi)
+ token_attention_fwd(
+ q,
+ mem_manager.key_buffer[layer_id],
+ mem_manager.value_buffer[layer_id],
+ output,
+ b_loc,
+ b_start_loc,
+ b_seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ alibi,
+ )
context_layer = output.view(batch_size, q_length, H * D_HEAD)
@@ -504,8 +525,8 @@ def bloom_attention_forward(
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
- context_layer[:, :, int(i * slices):int((i + 1) * slices)],
- self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
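The multi-line rewrite of the alibi computation makes the tensor-parallel split visible: every rank generates slopes for all `num_heads * tp_size` heads, then keeps only its own contiguous block. A simplified sketch for the power-of-two case (the non-power-of-two interpolation branch of `get_slopes` above is omitted):

```python
import math

import torch


def get_slopes_power_of_2(n):
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start**i for i in range(n)]


num_heads, tp_size, curr_tp_rank = 4, 2, 1  # 4 heads per rank, 8 heads in total
full = torch.tensor(get_slopes_power_of_2(num_heads * tp_size))
alibi = full[curr_tp_rank * num_heads : (curr_tp_rank + 1) * num_heads]
print(alibi)  # the 4 slopes owned by rank 1
```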
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 07b73a6f4ca6..4795162f1980 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -1,6 +1,5 @@
from typing import List, Optional, Tuple
-import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
@@ -15,6 +14,7 @@
try:
from vllm import layernorm_ops, pos_encoding_ops
+
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
@@ -29,17 +29,17 @@
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -71,8 +71,7 @@ def llama_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
-
- batch_size = input_ids.shape[0] # input_ids.shape[0]
+ batch_size = input_ids.shape[0] # input_ids.shape[0]
infer_state = self.infer_state
@@ -103,10 +102,11 @@ def llama_model_forward(
if use_cache and seq_length != 1:
        # NOTE assume prefill stage
# allocate memory block
- infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
- infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
- infer_state.context_mem_index)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
@@ -129,20 +129,20 @@ def llama_model_forward(
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(past_key_values_length,
- seq_length + past_key_values_length,
- dtype=torch.long,
- device=device)
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if infer_state.is_context_stage:
-
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
- position_ids.view(-1).shape[0], -1)
+ position_ids.view(-1).shape[0], -1
+ )
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
- position_ids.view(-1).shape[0], -1)
+ position_ids.view(-1).shape[0], -1
+ )
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
@@ -153,12 +153,13 @@ def llama_model_forward(
# embed positions
if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_length_with_past),
- dtype=torch.bool,
- device=inputs_embeds.device)
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
- attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
- past_key_values_length)
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
hidden_states = inputs_embeds
@@ -216,7 +217,6 @@ def llama_decoder_layer_forward(
use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -261,7 +261,6 @@ def llama_flash_attn_kvcache_forward(
use_cache: bool = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
    assert use_cache is True, "use_cache should be set to True when using this llama attention"
bsz, q_len, _ = hidden_states.size()
@@ -277,8 +276,8 @@ def llama_flash_attn_kvcache_forward(
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
- if infer_state.decode_layer_id == 0: # once per model.forward
- infer_state.cache_manager.past_key_values_length += q_len # seq_len
+ if infer_state.decode_layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
@@ -299,38 +298,62 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
# first token generation
# copy key and value calculated in current step to memory manager
- _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
- infer_state.cache_manager)
+ _copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
attn_output = torch.empty_like(query_states)
- llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc,
- infer_state.seq_len, infer_state.cache_manager.past_key_values_length)
+ llama_context_attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ )
else:
-
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
- infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
cache_k.copy_(key_states)
cache_v.copy_(value_states)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
- _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
- infer_state.decode_mem_index, infer_state.cache_manager)
+ _copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
- token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
- infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output,
- infer_state.block_loc, infer_state.start_loc, infer_state.seq_len,
- infer_state.cache_manager.past_key_values_length)
+ token_attention_fwd(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ )
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -341,7 +364,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
def get_llama_vllm_rmsnorm_forward():
-
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
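The black-formatted slices in `rotate_half` are behavior-preserving; a quick numeric check of what the function does to the last dimension:

```python
import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(rotate_half(x))  # tensor([[-3., -4.,  1.,  2.]])
```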
diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py
index 48f8db62c32a..fcb1b6a3bd8f 100644
--- a/colossalai/inference/tensor_parallel/policies/__init__.py
+++ b/colossalai/inference/tensor_parallel/policies/__init__.py
@@ -1,4 +1,4 @@
from .bloom import BloomModelInferPolicy
from .llama import LlamaModelInferPolicy
-__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy']
+__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"]
diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py
index cae43aa20421..2d18a3922c1e 100644
--- a/colossalai/inference/tensor_parallel/policies/bloom.py
+++ b/colossalai/inference/tensor_parallel/policies/bloom.py
@@ -9,6 +9,7 @@
try:
from colossalai.kernel.triton import layer_norm
+
HAS_TRITON_NORM = True
except:
print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
@@ -27,40 +28,40 @@ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
class BloomModelInferPolicy(BloomForCausalLMPolicy):
-
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
+
policy = super().module_policy()
# NOTE set inference mode to shard config
self.shard_config._infer()
method_replacement = {
- 'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
- 'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
+ "forward": BloomInferenceForwards.bloom_for_causal_lm_forward,
+ "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation,
}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=BloomForCausalLM)
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=BloomForCausalLM
+ )
- method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
+ method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
- method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
+ method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
- method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=BloomAttention)
+ method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=BloomAttention
+ )
if HAS_TRITON_NORM:
infer_method = get_triton_layernorm_forward()
- method_replacement = {'forward': partial(infer_method)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LayerNorm)
+ method_replacement = {"forward": partial(infer_method)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LayerNorm
+ )
return policy
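All of these policy changes boil down to one mechanism: a `method_replacement` dict maps attribute names on a target class to replacement callables. A minimal stand-in for that substitution (the real `append_or_create_method_replacement` lives in the ShardFormer policy base class and does more bookkeeping; the `Dummy` class here is invented for illustration):

```python
def patched_forward(self, x):
    return ("patched", x)


method_replacement = {"forward": patched_forward}


class Dummy:
    def forward(self, x):
        return x


for name, fn in method_replacement.items():
    setattr(Dummy, name, fn)

print(Dummy().forward(42))  # ('patched', 42)
```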
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
index 4844415d612c..9bbb547dbcae 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -10,6 +10,7 @@
try:
from colossalai.kernel.triton import rmsnorm_forward
+
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@@ -28,7 +29,6 @@ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
-
def __init__(self) -> None:
super().__init__()
@@ -37,20 +37,20 @@ def module_policy(self):
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
- method_replacement = {'forward': partial(infer_forward)}
+ method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
- method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LlamaDecoderLayer)
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+ )
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
- method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LlamaAttention)
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaAttention
+ )
infer_forward = None
if HAS_TRITON_RMSNORM:
@@ -60,9 +60,9 @@ def module_policy(self):
infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
- method_replacement = {'forward': partial(infer_forward)}
- self.append_or_create_method_replacement(description=method_replacement,
- policy=policy,
- target_key=LlamaRMSNorm)
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaRMSNorm
+ )
return policy
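The hunk also preserves the kernel-selection order: try the Triton RMSNorm first, fall back to the vLLM kernel, and leave `LlamaRMSNorm` untouched if neither import succeeded. A schematic of that fallback chain (flag values are faked so the snippet runs standalone; the real flags come from the guarded imports above):

```python
HAS_TRITON_RMSNORM, HAS_VLLM_KERNERL = False, True  # pretend Triton is missing

infer_forward = None
if HAS_TRITON_RMSNORM:
    infer_forward = "triton rmsnorm forward"
elif HAS_VLLM_KERNERL:
    infer_forward = "vllm rmsnorm forward"

if infer_forward is not None:
    print(f"replacing LlamaRMSNorm.forward with the {infer_forward}")
else:
    print("no fused kernel available; keeping the default forward")
```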
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index b8718abc80bd..aac57d34a2c1 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -14,15 +14,17 @@
from colossalai.utils import set_device, set_seed
-def launch(config: Union[str, Path, Config, Dict],
- rank: int,
- world_size: int,
- host: str,
- port: int,
- backend: str = 'nccl',
- local_rank: int = None,
- seed: int = 1024,
- verbose: bool = True):
+def launch(
+ config: Union[str, Path, Config, Dict],
+ rank: int,
+ world_size: int,
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ local_rank: int = None,
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
@@ -46,7 +48,7 @@ def launch(config: Union[str, Path, Config, Dict],
warnings.warn("`config` is deprecated and will be removed soon.")
# init default process group
- init_method = f'tcp://[{host}]:{port}'
+ init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
@@ -58,15 +60,17 @@ def launch(config: Union[str, Path, Config, Dict],
if verbose:
logger = get_dist_logger()
- logger.info(f'Distributed environment is initialized, world size: {dist.get_world_size()}', ranks=[0])
+ logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
-def launch_from_slurm(config: Union[str, Path, Config, Dict],
- host: str,
- port: int,
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+def launch_from_slurm(
+ config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM
@@ -79,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['SLURM_PROCID'])
- world_size = int(os.environ['SLURM_NPROCS'])
+ rank = int(os.environ["SLURM_PROCID"])
+ world_size = int(os.environ["SLURM_NPROCS"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
)
- launch(config=config,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def launch_from_openmpi(config: Union[str, Path, Config, Dict],
- host: str,
- port: int,
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+ launch(
+ config=config,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
+
+
+def launch_from_openmpi(
+ config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI
@@ -114,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
- local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
- world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+ local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+ world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
)
- launch(config=config,
- local_rank=local_rank,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def launch_from_torch(config: Union[str, Path, Config, Dict],
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+ launch(
+ config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
+
+
+def launch_from_torch(
+ config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
+):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
@@ -147,22 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['RANK'])
- local_rank = int(os.environ['LOCAL_RANK'])
- world_size = int(os.environ['WORLD_SIZE'])
- host = os.environ['MASTER_ADDR']
- port = int(os.environ['MASTER_PORT'])
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ host = os.environ["MASTER_ADDR"]
+ port = int(os.environ["MASTER_PORT"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
- launch(config=config,
- local_rank=local_rank,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
+ launch(
+ config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
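Note: a minimal usage sketch for the launchers reformatted above; the script name and
argument values are illustrative assumptions, not part of this patch:

    # train.py -- start with: torchrun --nproc_per_node=2 train.py
    import colossalai

    # launch_from_torch reads RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and
    # MASTER_PORT from the environment prepared by torchrun and forwards them
    # to colossalai.launch (the `config` argument is deprecated, per this patch).
    colossalai.launch_from_torch(config={}, backend="nccl", seed=1024, verbose=True)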
diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py
index 1c3199fc1aff..98b21c9c02c1 100644
--- a/colossalai/interface/__init__.py
+++ b/colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper
-__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
+__all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"]
diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py
index 7b3d9435d255..58df09b853ee 100644
--- a/colossalai/interface/model.py
+++ b/colossalai/interface/model.py
@@ -26,11 +26,9 @@ def forward(self, *args, **kwargs):
class AMPModelMixin:
- """This mixin class defines the interface for AMP training.
- """
+ """This mixin class defines the interface for AMP training."""
def update_master_params(self):
"""
Update the master parameters for AMP training.
"""
- pass
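Note: a minimal sketch of a wrapper implementing this interface; the FP32 master-copy
scheme, the class name, and the assumption that ModelWrapper stores the wrapped module
as self.module are all illustrative, not part of this patch:

    import torch
    from colossalai.interface import AMPModelMixin, ModelWrapper

    class NaiveAMPModel(ModelWrapper, AMPModelMixin):
        """Keeps an FP32 master copy beside the (possibly FP16) working weights."""

        def __init__(self, module: torch.nn.Module):
            super().__init__(module)
            self.master_params = [p.detach().clone().float() for p in module.parameters()]

        def update_master_params(self):
            # Refresh the FP32 masters from the current working weights.
            for master, working in zip(self.master_params, self.module.parameters()):
                master.data.copy_(working.data.float())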
diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index bc270b1d9c89..95d11087bece 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -22,7 +22,7 @@ def parameters(self):
params = []
for group in self.param_groups:
- params += group['params']
+ params += group["params"]
return params
@property
@@ -82,12 +82,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
"""
nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2.0,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> Tensor:
+ def clip_grad_by_norm(
+ self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs,
+ ) -> Tensor:
"""
Clips gradient norm of an iterable of parameters.
@@ -113,7 +115,8 @@ def scale_loss(self, loss: Tensor):
loss (Tensor): The loss to be scaled.
"""
raise NotImplementedError(
- "The method scale_loss is only available for optimizers with mixed precision training")
+ "The method scale_loss is only available for optimizers with mixed precision training"
+ )
def unscale_grad(self):
"""
@@ -122,7 +125,8 @@ def unscale_grad(self):
Note: Only available for optimizers with mixed precision training.
"""
raise NotImplementedError(
- "The method unscale_grad is only available for optimizers with mixed precision training")
+ "The method unscale_grad is only available for optimizers with mixed precision training"
+ )
def unwrap(self):
"""
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index e0136d86e561..f8a974b5fb26 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -4,6 +4,10 @@
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
__all__ = [
- 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
- 'AttnMaskType'
+ "LayerNorm",
+ "MultiHeadAttention",
+ "FusedScaleMaskSoftmax",
+ "ScaledUpperTriangMaskedSoftmax",
+ "ColoAttention",
+ "AttnMaskType",
]
diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h
index 00066dc95475..a62beef91a8a 100644
--- a/colossalai/kernel/cuda_native/csrc/compat.h
+++ b/colossalai/kernel/cuda_native/csrc/compat.h
@@ -7,4 +7,4 @@
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
-#endif
\ No newline at end of file
+#endif
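Note: DATA_PTR papers over the rename of Tensor::data<T>() to data_ptr<T>() between
PyTorch versions; a sketch of a typical call site in a C++ extension (function and
variable names are illustrative assumptions):

    #include <torch/extension.h>
    #include "compat.h"

    void scale_inplace(torch::Tensor t, float factor) {
      // Expands to t.data_ptr<float>() when VERSION_GE_1_3 is defined,
      // and to the legacy t.data<float>() otherwise.
      float *raw = t.DATA_PTR<float>();
      for (int64_t i = 0; i < t.numel(); ++i) raw[i] *= factor;
    }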
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
index 26efa2ad6f31..9a6a8ebc3983 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
@@ -1,7 +1,6 @@
#include
#include
-
#include "cuda_util.h"
/* GPU function guard */
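Note: the "GPU function guard" kept in this file is the usual check-and-report wrapper
around CUDA runtime calls; a minimal sketch of the pattern (macro name and message
format are assumptions, not the file's exact code):

    #include <cstdio>
    #include <cuda_runtime.h>

    #define GUARD_CUDA(call)                                      \
      do {                                                        \
        cudaError_t status_ = (call);                             \
        if (status_ != cudaSuccess) {                             \
          fprintf(stderr, "CUDA error %s at %s:%d\n",             \
                  cudaGetErrorString(status_), __FILE__, __LINE__); \
        }                                                         \
      } while (0)

    // usage: GUARD_CUDA(cudaStreamSynchronize(stream));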
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
index a39a6dae0f7f..ce0b017f12e1 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
@@ -1,1002 +1,1002 @@
-#include <chrono>
-#include <curand_kernel.h>
-
-#include "kernels.h"
-
-#include <cooperative_groups.h>
-
-
-namespace cg = cooperative_groups;
-
-curandStatePhilox4_32_10_t *curandstate;
-
-/**
- * @brief element-wise activation function on device, like Relu, Gelu
- *
- * @tparam enum class ActivationType, kRelu, kGelu
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape and type as input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_kernel(T x);
-
-template <>
-__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
- float cdf =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
- return x * cdf;
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
- __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
- float2 tmp_pow = __half22float2(val_pow3);
- float2 tmp = __half22float2(val);
-
- tmp.x =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
- tmp.y =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
- return __hmul2(val, __float22half2_rn(tmp));
-}
-
-template <>
-__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
- return fmaxf(x, 0);
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
- return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
- fmaxf(0.f, __half2float(x.y)));
-}
-
-/**
- * @brief element-wise activation backward function on device
- *
- * @tparam enum class ActivationType
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape as input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
- float x) {
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * (dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
- __half grad, __half x_half) {
- float x = __half2float(x_half);
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * __float2half(dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
- float x) {
- return x > 0.f ? grad : 0.f;
-}
-
-template <>
-__device__ __half
-activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
- const __half half_zero = __float2half(0.f);
- return x > half_zero ? grad : half_zero;
-}
-
-template <>
-__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
- __half2 grad2, __half2 x_half2) {
- const __half half_zero = __float2half(0.f);
- return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
- x_half2.y > half_zero ? grad2.y : half_zero);
-}
-
-/**
- * @brief init curand states in global memory
- *
- * @thread grid_dim * block_dim to support any size of states
- * @param state persistent curand states
- * @param seed seed to init states
- * @return void
- */
-__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
- int seed) {
- /* Each thread gets same seed, a different sequence
- number, no offset */
- int id = threadIdx.x + blockIdx.x * blockDim.x;
- curand_init(seed, id, 0, &state[id]);
-}
-
-void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
- cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
- int grid_dim = total_count >> 9;
- curand_init_kernel<<<grid_dim + 1, 512, 0, stream>>>(
- curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
-}
-
-/**
- * @brief element-wise dropout, store dropped position in mask, it's not
- * in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out any size of float and __half
- * @param in same with out
- * @param mask uint8 type, same size as out
- * @param seed seed to curand
- * @return void
- */
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- float *__restrict__ out,
- const float *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
-
- float4 input4 = data4[i];
- float4 res4;
- res4.x = input4.x * scale * m[0];
- res4.y = input4.y * scale * m[1];
- res4.z = input4.z * scale * m[2];
- res4.w = input4.w * scale * m[3];
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- __half *__restrict__ out,
- const __half *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-/**
- * @brief element-wise dropout backward with dropout mask, it's
- * not in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param in any size of float and __half
- * @param mask uint8 type, same size as in
- * @return void
- */
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- float *out, const float *in,
- const uint8_t *__restrict__ mask) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *in4 = reinterpret_cast<const float4 *>(in);
- const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);
-
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- m4[0] = mask4[i];
-
- float4 input4 = in4[i];
- float4 res4;
- res4.x = input4.x * scale * static_cast<float>(m[0]);
- res4.y = input4.y * scale * static_cast<float>(m[1]);
- res4.z = input4.z * scale * static_cast<float>(m[2]);
- res4.w = input4.w * scale * static_cast<float>(m[3]);
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- __half *out, const __half *in,
- const uint8_t *__restrict__ mask) {
- const __half scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);
-
- uint8_t m[8];
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- m8[0] = mask8[i];
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- out4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
- int total_count, float ratio, cudaStream_t stream,
- bool backward) {
- int grid_dim = total_count >> 12;
- if (!backward) {
- ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
- ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
- out, vals, mask);
- }
-}
-
-template <>
-void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
- int total_count, float ratio,
- cudaStream_t stream, bool backward) {
- int grid_dim = total_count >> 13;
- if (!backward) {
- ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
- ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
- out, vals, mask);
- }
-}
-
-/**
- * @brief fused bias, dropout, and residual at the end of Attention and FFN,
- * store dropped position in mask, it's not in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param residual [batch_size, seq_len, hidden_size], float and __half
- * @param seed seed to curand
- * @param hidden_size hidden size
- * @return void
- */
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const float *__restrict__ residual,
- const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = static_cast<uint8_t>(rand.x > ratio);
- m[1] = static_cast<uint8_t>(rand.y > ratio);
- m[2] = static_cast<uint8_t>(rand.z > ratio);
- m[3] = static_cast<uint8_t>(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 output4;
-
- output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
- output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
- output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
- output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;
-
- out4[i] = output4;
-}
-
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const __half *__restrict__ residual,
- const int seed, const int hidden_size) {
- const __half scale = 1. / (1. - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = static_cast<uint8_t>(rand.x > ratio);
- m[1] = static_cast<uint8_t>(rand.y > ratio);
- m[2] = static_cast<uint8_t>(rand.z > ratio);
- m[3] = static_cast<uint8_t>(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = static_cast<uint8_t>(rand.x > ratio);
- m[5] = static_cast<uint8_t>(rand.y > ratio);
- m[6] = static_cast<uint8_t>(rand.z > ratio);
- m[7] = static_cast<uint8_t>(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = m8[0];
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
- const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] =
- __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]);
- out_half2[1] =
- __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]);
- out_half2[2] =
- __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]);
- out_half2[3] =
- __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
- uint8_t *mask, const float *bias,
- const float *residual, int total_count,
- int dim, float ratio,
- cudaStream_t stream) {
- int grid_dim = total_count >> 12;
- ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias, residual,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
- uint8_t *mask, const __half *bias,
- const __half *residual, int total_count,
- int dim, float ratio,
- cudaStream_t stream) {
- int grid_dim = total_count >> 13;
- ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias, residual,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-/**
- * @brief fused bias and dropout backward at the end of Attention and FFN
- *
- * @thread
- * gridDim.x = hidden_size / 8
- * blockDim.x = 8
- * blockDim.y = 1024 / 8 = 128
- *
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, float *__restrict__ in_grad,
- float *__restrict__ bias_grad, const float *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- // every block generates 8 bias results
- __shared__ float tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- float local_sum = 0;
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- float val = out_grad[idx];
- val *= scale * static_cast<float>(mask[idx]);
- local_sum += val;
- in_grad[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- float sum = 0;
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad[pos] = tile[0][threadIdx.x];
- }
-}
-
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, __half *__restrict__ in_grad,
- __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
- __shared__ __half2 tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
- const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
- __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- __half2 local_sum = __float2half2_rn(0.f);
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- __half2 val = out_grad2[idx];
- __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
- val *= scale * m2;
- local_sum += val;
- in_grad2[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- __half2 sum = __float2half2_rn(0.f);
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad2[pos] = tile[0][threadIdx.x];
- }
-}
-
-template <typename T>
-void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
- const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream) {
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
- ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
-
-template <>
-void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
- const __half *out_grad, const uint8_t *mask,
- int row_size, int dim, float ratio,
- cudaStream_t stream) {
- dim >>= 1;
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
- ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
-
-template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
- const float *out_grad,
- const uint8_t *mask, int row_size,
- int dim, float ratio,
- cudaStream_t stream);
-
-/**
- * @brief fused bias, activation, and dropout at the end of first ffn
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 256
- *
- * @tparam act_type activation function, like kRelu, kGelu
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param seed seed to curand
- * @param hidden_size
- * @return void
- */
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast(out);
- const float4 *data4 = reinterpret_cast(in);
- const float4 *bias4 = reinterpret_cast(bias);
- uint32_t *mask4 = reinterpret_cast(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
- uint32_t *m4 = reinterpret_cast(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 output4;
-
- output4.x =
- activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
- output4.y =
- activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
- output4.z =
- activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
- output4.w =
- activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];
-
- out4[i] = output4;
-}
-
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
-
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
- out_half2[0] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
- scale_mask_1);
- out_half2[1] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
- scale_mask_2);
- out_half2[2] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
- scale_mask_3);
- out_half2[3] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
- scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
- ls_dropout_act_bias_kernel<ActivationType::kGelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
- ls_dropout_act_bias_kernel<ActivationType::kGelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
- ls_dropout_act_bias_kernel<ActivationType::kRelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
- ls_dropout_act_bias_kernel<ActivationType::kRelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-/**
- * @brief fused bias, activation, and dropout backward
- *
- * @thread
- * gridDim.x = hidden_size / WARP_SIZE
- * blockDim.x = WARP_SIZE
- * blockDim.y = WARP_SIZE
- *
- * @tparam act_type kRelu
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
-template <ActivationType act_type, typename T>
-__global__ void ls_dropout_act_bias_bwd_kernel(
- const int row_size, const float ratio, T *in_grad,
- T *__restrict__ bias_grad, const T *__restrict__ input,
- const T *__restrict__ bias, const T *out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-
- int stride = hidden_size * WARP_SIZE;
- float local_sum = 0;
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- if (col_idx < hidden_size) {
- for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
- float val = out_grad[idx];
- float in = input[idx];
- float b = bias[idx % hidden_size];
- val = activation_bwd_kernel<act_type, float>(
- val * scale * static_cast<float>(mask[idx]), in + b);
- local_sum += val;
- in_grad[idx] = val;
- idx += stride;
- }
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
- float sum = tile[threadIdx.y][threadIdx.x];
- __syncthreads();
-
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
- __syncthreads();
-
- if (threadIdx.y == 0) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
- bias_grad[pos] = tile[0][threadIdx.x];
- }
-}
-
-// @brief fused bias, activation, and dropout backward
-// It is deprecated for precision reason. Keep it for future optimization.
-//
-// template <ActivationType act_type>
-// __global__ void ls_dropout_act_bias_bwd_kernel(
-// const int row_size, const float ratio, __half * in_grad,
-// __half *__restrict__ bias_grad, const __half *__restrict__ input, const
-// __half *__restrict__ bias, const __half * out_grad, const uint8_t
-// *__restrict__ mask, const int hidden_size) {
-// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
-// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
-
-// cg::thread_block b = cg::this_thread_block();
-// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
-// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
-// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
-// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
-// const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
-// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
-
-// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-
-// int stride = hidden_size * WARP_SIZE;
-// __half2 local_sum = __float2half2_rn(0.f);
-
-// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
-// if (col_idx < hidden_size) {
-// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
-// __half2 val = out_grad2[idx];
-// __half2 in2 = input2[idx];
-// __half2 b2 = bias2[idx % hidden_size ];
-// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
-// val = activation_bwd_kernel<act_type, __half2>(val * scale * m2,
-// in2 + b2);
-// local_sum += val;
-// in_grad2[idx] = val;
-// idx += stride;
-// }
-// }
-
-// tile[threadIdx.x][threadIdx.y] = local_sum;
-// __syncthreads();
-// __half2 sum = tile[threadIdx.y][threadIdx.x];
-// __syncthreads();
-
-// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
-// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
-// __syncthreads();
-
-// if (threadIdx.y == 0) {
-// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-// bias_grad2[pos] = tile[0][threadIdx.x];
-// }
-// }
-
-template <ActivationType act_type, typename T>
-void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
- const T *bias, const T *out_grad,
- const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream) {
- dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
- dim3 block_dim(WARP_SIZE, WARP_SIZE);
- ls_dropout_act_bias_bwd_kernel<act_type><<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
-}
-
-// template <>
-// void launch_ls_dropout_act_bias_bwd(
-// __half *in_grad, __half *bias_grad,const __half *input, const __half
-// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
-// dim, float ratio, cudaStream_t stream) {
-// dim >>= 1;
-// dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
-// dim3 block_dim(WARP_SIZE, WARP_SIZE);
-// ls_dropout_act_bias_bwd_kernel<act_type>
-// <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
-// bias_grad,
-// input, bias,out_grad, mask, dim);
-// }
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
- float *in_grad, float *bias_grad, const float *input, const float *bias,
- const float *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
- __half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
- const __half *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
- float *in_grad, float *bias_grad, const float *input, const float *bias,
- const float *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
- __half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
- const __half *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
+#include <chrono>
+#include <curand_kernel.h>
+
+#include "kernels.h"
+
+#include <cooperative_groups.h>
+
+
+namespace cg = cooperative_groups;
+
+curandStatePhilox4_32_10_t *curandstate;
+
+/**
+ * @brief element-wise activation function on device, like Relu, Gelu
+ *
+ * @tparam enum class ActivationType, kRelu, kGelu
+ * @tparam input type
+ * @param any shape of float and __half2
+ * @return same shape and type as input
+ */
+template <ActivationType act_type, typename T>
+__forceinline__ __device__ T activation_kernel(T x);
+
+template <>
+__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
+ float cdf =
+ 0.5f *
+ (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
+ return x * cdf;
+}
+
+template <>
+__device__ __half2
+activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
+ __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
+ float2 tmp_pow = __half22float2(val_pow3);
+ float2 tmp = __half22float2(val);
+
+ tmp.x =
+ 0.5f *
+ (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
+ tmp.y =
+ 0.5f *
+ (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
+ return __hmul2(val, __float22half2_rn(tmp));
+}
+
+template <>
+__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
+ return fmaxf(x, 0);
+}
+
+template <>
+__device__ __half2
+activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
+ return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
+ fmaxf(0.f, __half2float(x.y)));
+}
+
+/**
+ * @brief element-wise activation backward function on device
+ *
+ * @tparam enum class ActivationType
+ * @tparam input type
+ * @param any shape of float and __half2
+ * @return same shape as input
+ */
+template <ActivationType act_type, typename T>
+__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
+
+template <>
+__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
+ float x) {
+ const float sqrt_param = 0.79788456080286535587989211986876f;
+ const float mul_param = 0.044715;
+
+ float x2mul = x * x * mul_param;
+ float tan_h = tanhf(sqrt_param * (x + x * x2mul));
+ float dg1 = 0.5f * (1.0f + tan_h);
+ float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
+ float dg3 = dg2 * 3 * x2mul;
+ return grad * (dg1 + dg2 + dg3);
+}
+
+template <>
+__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
+ __half grad, __half x_half) {
+ float x = __half2float(x_half);
+ const float sqrt_param = 0.79788456080286535587989211986876f;
+ const float mul_param = 0.044715;
+
+ float x2mul = x * x * mul_param;
+ float tan_h = tanhf(sqrt_param * (x + x * x2mul));
+ float dg1 = 0.5f * (1.0f + tan_h);
+ float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
+ float dg3 = dg2 * 3 * x2mul;
+ return grad * __float2half(dg1 + dg2 + dg3);
+}
+
+template <>
+__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
+ float x) {
+ return x > 0.f ? grad : 0.f;
+}
+
+template <>
+__device__ __half
+activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
+ const __half half_zero = __float2half(0.f);
+ return x > half_zero ? grad : half_zero;
+}
+
+template <>
+__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
+ __half2 grad2, __half2 x_half2) {
+ const __half half_zero = __float2half(0.f);
+ return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
+ x_half2.y > half_zero ? grad2.y : half_zero);
+}
+
+/**
+ * @brief init curand states in global memory
+ *
+ * @thread grid_dim * block_dim to support any size of states
+ * @param state persistent curand states
+ * @param seed seed to init states
+ * @return void
+ */
+__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
+ int seed) {
+ /* Each thread gets same seed, a different sequence
+ number, no offset */
+ int id = threadIdx.x + blockIdx.x * blockDim.x;
+ curand_init(seed, id, 0, &state[id]);
+}
+
+void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
+ cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
+ int grid_dim = total_count >> 9;
+ curand_init_kernel<<<grid_dim + 1, 512, 0, stream>>>(
+ curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count());
+}
+
+/**
+ * @brief element-wise dropout, store dropped position in mask, it's not
+ * in-place
+ *
+ * @thread
+ * gridDim.x = total_count / 1024
+ * blockDim.x = 1024
+ *
+ * @param total_count total elements
+ * @param ratio drop ratio
+ * @param out any size of float and __half
+ * @param in same with out
+ * @param mask uint8 type, same size as out
+ * @param seed seed to curand
+ * @return void
+ */
+__global__ void ls_dropout_kernel(const int total_count, const float ratio,
+ float *__restrict__ out,
+ const float *__restrict__ in,
+ uint8_t *__restrict__ mask, const int seed) {
+ const float scale = 1.f / (1.f - ratio);
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (i * 4 >= total_count) return;
+
+ curandStatePhilox4_32_10_t state;
+ curand_init(seed, i, 0, &state);
+ uint8_t m[4];
+
+ float4 *out4 = reinterpret_cast<float4 *>(out);
+ const float4 *data4 = reinterpret_cast<const float4 *>(in);
+ uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
+ float4 rand = curand_uniform4(&state);
+
+ m[0] = (uint8_t)(rand.x > ratio);
+ m[1] = (uint8_t)(rand.y > ratio);
+ m[2] = (uint8_t)(rand.z > ratio);
+ m[3] = (uint8_t)(rand.w > ratio);
+
+ uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
+ mask4[i] = m4[0];
+
+ float4 input4 = data4[i];
+ float4 res4;
+ res4.x = input4.x * scale * m[0];
+ res4.y = input4.y * scale * m[1];
+ res4.z = input4.z * scale * m[2];
+ res4.w = input4.w * scale * m[3];
+ out4[i] = res4;
+}
+
+__global__ void ls_dropout_kernel(const int total_count, const float ratio,
+ __half *__restrict__ out,
+ const __half *__restrict__ in,
+ uint8_t *__restrict__ mask, const int seed) {
+ const float scale = 1.f / (1.f - ratio);
+
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (i * 8 >= total_count) return;
+
+ curandStatePhilox4_32_10_t state;
+ curand_init(seed, i, 0, &state);
+
+ const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
+ float4 *outs_float4 = reinterpret_cast<float4 *>(out);
+ uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
+
+ uint8_t m[8];
+ float4 rand = curand_uniform4(&state);
+ m[0] = (uint8_t)(rand.x > ratio);
+ m[1] = (uint8_t)(rand.y > ratio);
+ m[2] = (uint8_t)(rand.z > ratio);
+ m[3] = (uint8_t)(rand.w > ratio);
+ rand = curand_uniform4(&state);
+ m[4] = (uint8_t)(rand.x > ratio);
+ m[5] = (uint8_t)(rand.y > ratio);
+ m[6] = (uint8_t)(rand.z > ratio);
+ m[7] = (uint8_t)(rand.w > ratio);
+ uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
+ mask8[i] = *m8;
+
+ float4 val_float4 = vals_float4[i];
+ float4 out_float4;
+ __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
+ __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
+ __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
+ __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
+ __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
+ __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
+ out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
+ out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
+ out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
+ out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
+ outs_float4[i] = out_float4;
+}
+
+/**
+ * @brief element-wise dropout backward with dropout mask, it's
+ * not in-place
+ *
+ * @thread
+ * gridDim.x = total_count / 1024
+ * blockDim.x = 1024
+ *
+ * @param total_count total elements
+ * @param ratio drop ratio
+ * @param in any size of float and __half
+ * @param mask uint8 type, same size as in
+ * @return void
+ */
+__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
+ float *out, const float *in,
+ const uint8_t *__restrict__ mask) {
+ const float scale = 1.f / (1.f - ratio);
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (i * 4 >= total_count) return;
+
+ uint8_t m[4];
+
+ float4 *out4 = reinterpret_cast<float4 *>(out);
+ const float4 *in4 = reinterpret_cast<const float4 *>(in);
+ const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);
+
+ uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
+ m4[0] = mask4[i];
+
+ float4 input4 = in4[i];
+ float4 res4;
+ res4.x = input4.x * scale * static_cast<float>(m[0]);
+ res4.y = input4.y * scale * static_cast<float>(m[1]);
+ res4.z = input4.z * scale * static_cast<float>(m[2]);
+ res4.w = input4.w * scale * static_cast<float>(m[3]);
+ out4[i] = res4;
+}
+
+__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
+ __half *out, const __half *in,
+ const uint8_t *__restrict__ mask) {
+ const __half scale = 1.f / (1.f - ratio);
+
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (i * 8 >= total_count) return;
+
+ float4 *out4 = reinterpret_cast<float4 *>(out);
+ const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
+ const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);
+
+ uint8_t m[8];
+ uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
+ m8[0] = mask8[i];
+
+ float4 val_float4 = vals_float4[i];
+ float4 out_float4;
+ __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
+ __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
+ __half2 scale_mask_1 =
+ __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
+ __half2 scale_mask_2 =
+ __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
+ __half2 scale_mask_3 =
+ __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
+ __half2 scale_mask_4 =
+ __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
+ out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
+ out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
+ out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
+ out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
+ out4[i] = out_float4;
+}
+
+template <>
+void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
+ int total_count, float ratio, cudaStream_t stream,
+ bool backward) {
+ int grid_dim = total_count >> 12;
+ if (!backward) {
+ ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
+ total_count, ratio, out, vals, mask,
+ std::chrono::duration_cast<std::chrono::microseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count());
+ } else {
+ ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
+ out, vals, mask);
+ }
+}
+
+template <>
+void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
+ int total_count, float ratio,
+ cudaStream_t stream, bool backward) {
+ int grid_dim = total_count >> 13;
+ if (!backward) {
+ ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
+ total_count, ratio, out, vals, mask,
+ std::chrono::duration_cast<std::chrono::microseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count());
+ } else {
+ ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
+ out, vals, mask);
+ }
+}
+
+/**
+ * @brief fused bias, dropout, and residual at the end of Attention and FFN,
+ * store dropped position in mask, it's not in-place
+ *
+ * @thread
+ * gridDim.x = total_count / 1024
+ * blockDim.x = 1024
+ *
+ * @param total_count total elements
+ * @param ratio drop ratio
+ * @param out [batch_size, seq_len, hidden_size], float and __half
+ * @param in [batch_size, seq_len, hidden_size], float and __half
+ * @param mask [batch_size, seq_len, hidden_size], uint8 type
+ * @param bias [hidden_size], ffn bias
+ * @param residual [batch_size, seq_len, hidden_size], float and __half
+ * @param seed seed to curand
+ * @param hidden_size hidden size
+ * @return void
+ */
+__global__ void ls_dropout_res_bias_kernel(
+ const int total_count, const float ratio, float *__restrict__ out,
+ const float *__restrict__ in, uint8_t *__restrict__ mask,
+ const float *__restrict__ bias, const float *__restrict__ residual,
+ const int seed, const int hidden_size) {
+ const float scale = 1.f / (1.f - ratio);
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (i * 4 >= total_count) return;
+
+ curandStatePhilox4_32_10_t state;
+ curand_init(seed, i, 0, &state);
+ uint8_t m[4];
+
+ float4 *out4 = reinterpret_cast<float4 *>(out);
+ const float4 *data4 = reinterpret_cast<const float4 *>(in);
+ const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
+ const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
+ uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
+ float4 rand = curand_uniform4(&state);
+
+ m[0] = static_cast<uint8_t>(rand.x > ratio);
+ m[1] = static_cast<uint8_t>(rand.y > ratio);
+ m[2] = static_cast<uint8_t>(rand.z > ratio);
+ m[3] = static_cast<uint8_t>(rand.w > ratio);
+
+ int bias_i = i % (hidden_size >> 2);
+ uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
+ mask4[i] = m4[0];
+ const float4 input4 = data4[i];
+ const float4 b4 = __ldg(&bias4[bias_i]);
+ const float4 res4 = residual4[i];
+ float4 output4;
+
+ output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
+ output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
+ output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
+ output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;
+
+ out4[i] = output4;
+}
+
+__global__ void ls_dropout_res_bias_kernel(
+ const int total_count, const float ratio, __half *__restrict__ out,
+ const __half *__restrict__ in, uint8_t *__restrict__ mask,
+ const __half *__restrict__ bias, const __half *__restrict__ residual,
+ const int seed, const int hidden_size) {
+ const __half scale = 1. / (1. - ratio);
+
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (i * 8 >= total_count) return;
+
+ curandStatePhilox4_32_10_t state;
+ curand_init(seed, i, 0, &state);
+
+ const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
+ float4 *outs_float4 = reinterpret_cast<float4 *>(out);
+ const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
+ const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
+ uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
+
+ uint8_t m[8];
+ float4 rand = curand_uniform4(&state);
+ m[0] = static_cast<uint8_t>(rand.x > ratio);
+ m[1] = static_cast<uint8_t>(rand.y > ratio);
+ m[2] = static_cast<uint8_t>(rand.z > ratio);
+ m[3] = static_cast<uint8_t>(rand.w > ratio);
+ rand = curand_uniform4(&state);
+ m[4] = static_cast<uint8_t>(rand.x > ratio);
+ m[5] = static_cast<uint8_t>(rand.y > ratio);
+ m[6] = static_cast<uint8_t>(rand.z > ratio);