From 9a4842c571cd63e6a660182a234bc6ff60991ba0 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Mon, 17 Jul 2023 17:30:57 +0800
Subject: [PATCH 1/6] revise shardformer readme (#4246)
---
colossalai/shardformer/README.md | 25 ++++++++++---------------
1 file changed, 10 insertions(+), 15 deletions(-)
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 6ae32e4fbd42..bf4215c52980 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -22,7 +22,6 @@
- [System Performance](#system-performance)
- [Convergence](#convergence)
-
## 🔗 Introduction
**Shardformer** is a module that automatically parallelizes mainstream models from libraries such as HuggingFace and TIMM. It aims to make parallelization hassle-free for users who do not have a systems background.
@@ -33,7 +32,7 @@
The sample API usage is given below:
-``` python
+```python
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertForMaskedLM
@@ -74,6 +73,7 @@ shard_former.optimize(model, my_policy)
```
+
## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:
@@ -117,15 +117,13 @@ Please refer to the code for more details.
-
-
### Distributed Modules
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
-```python
+````python
class ParallelModule(torch.nn.Module):
@abstractmethod
@@ -140,7 +138,7 @@ class ParallelModule(torch.nn.Module):
my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
```
"""
-```
+````
### Shard Config
@@ -169,7 +167,7 @@ We abstract the policy into four stages:
2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` objects that tell `ModelSharder` how each submodule's attributes, child parameters, and deeper submodules will be substituted.
3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
-``` python
+```python
@dataclass
class ModulePolicyDescription:
r"""
@@ -238,7 +236,6 @@ class Policy(ABC):
...
```
-
### Model Sharder
`ModelSharder` is the class in charge of sharding the model based on the given policy.
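To make this concrete, the sketch below shows roughly how a sharder can walk a model and apply `ModulePolicyDescription` entries. It is an illustration only: the real `ModelSharder` also handles nested suffixes, parameter replacement, and method binding, and the helper name `apply_module_policy` is hypothetical.

```python
from typing import Dict

def apply_module_policy(model, policies: Dict[type, "ModulePolicyDescription"]):
    """Hypothetical sketch of the sharding pass, not the actual implementation."""
    for _, module in model.named_modules():
        desc = policies.get(type(module))
        if desc is None:
            continue
        # replace plain attributes, e.g. divide num_heads by the tp world size
        for attr, new_value in (desc.attribute_replacement or {}).items():
            setattr(module, attr, new_value)
        # swap whole submodules for their distributed counterparts
        for sub in (desc.sub_module_replacement or []):
            child = getattr(module, sub.suffix)
            setattr(module, sub.suffix,
                    sub.target_module.from_native_module(child, process_group=None))
```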
@@ -324,21 +321,20 @@ You can create a new file in the `colossalai/shardformer/policies` folder and na
Please follow these protocols when writing your policy:
- You have to make a clear decision about what exactly you want to replace in the original PyTorch module
- - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes
- - Use `ModulePolicyDescription.param_replacement` to replace the module parameters
- - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the .
- - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**.
+ - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes
+ - Use `ModulePolicyDescription.param_replacement` to replace the module parameters
+ - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the replacement.
+ - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**.
- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/.py` file. Primitive modules refer to modules which are not composed of other modules. For example, the `torch.nn.Linear` module is a primitive module, while modules such as `BertEncoder` in the `transformers` library are composite modules. Primitive modules do not contain nested `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement.
- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert a native PyTorch module to the `ParallelModule`, and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should only be implemented for modules whose weights are sharded. If you want to make your module compatible with `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting from `ParallelModule`, as in `colossalai/shardformer/layer/normalization.py`.
- **Do not import any file in the `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` folders to avoid unwanted import errors**. For example, if a file in these folders accidentally imports the `transformers` library at the top of the file, then users will have to install `transformers` even if they do not use this file. Any file in the `modeling` folder should only be imported by the policy file. A policy implementation should only be imported dynamically via the autopolicy or manually via the `ShardFormer` module.
- Keep import statements for third-party libraries such as `transformers` within the function body instead of the header section of the file, because we do not want to import the library when the policy is not used. A skeleton policy illustrating these protocols is sketched below.
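For orientation, here is a loose skeleton of a policy that follows these protocols. Treat the import path (`colossalai.shardformer.policies.basepolicy`), the base-class requirements (the real `Policy` may require extra abstract methods such as `config_sanity_check`), and the constructor arguments as assumptions to verify against the code:

```python
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.basepolicy import (    # path is an assumption
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)


class MyModelPolicy(Policy):

    def preprocess(self):
        # e.g. resize embeddings so the vocab size divides the tp size
        return self.model

    def module_policy(self):
        # keep third-party imports inside the method, per the protocol above
        from transformers.models.bert.modeling_bert import BertLayer

        return {
            BertLayer:
                ModulePolicyDescription(sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.self.query",
                        target_module=col_nn.Linear1D_Col,
                    ),
                ])
        }

    def postprocess(self):
        # e.g. re-bind embedding and classifier head weights
        return self.model
```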
-
- Step 2. Register your policy with the autopolicy
Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.
-For example, if we register the policy for the BERT model, we just add a key-value in the `_POLICY_LIST` dictionary. The key if the `qualname` of the model object (you can get it by model.__class__.__qualname__). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.
+For example, if we register the policy for the BERT model, we just add a key-value pair in the `_POLICY_LIST` dictionary. The key is the `qualname` of the model class (you can get it via model.\_\_class\_\_.\_\_qualname\_\_). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may import libraries (such as `transformers`) which we do not want to load when the policy is unused.
```python
_POLICY_LIST = {
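    # the key is the qualname of the model class and the value is a
    # PolicyLocation; an illustrative, hypothetical entry:
    #   "transformers.models.bert.modeling_bert.BertForMaskedLM":
    #       PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),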
@@ -360,7 +356,6 @@ Add your model to the `tests/kit/model_zoo` file. This allows you to define test
Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
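A registration sketch is shown below; the helper names (`model_zoo.register`, `ModelAttribute`) follow the existing model zoo files, but treat the exact signature as an assumption and copy from a neighboring file in `tests/kit/model_zoo` when in doubt:

```python
# hypothetical file: tests/kit/model_zoo/transformers/my_model.py
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# assumed signature: factories for the model, its input batch, and an output
# transform that the shardformer tests use to compare outputs
model_zoo.register(name='transformers_my_model',
                   model_fn=lambda: transformers.BertForMaskedLM(
                       transformers.BertConfig(num_hidden_layers=2)),
                   data_gen_fn=lambda: dict(input_ids=torch.randint(0, 100, (2, 16)),
                                            attention_mask=torch.ones(2, 16, dtype=torch.long)),
                   output_transform_fn=lambda x: x,
                   model_attribute=ModelAttribute(has_control_flow=True))
```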
-
- Step 3. Execute your test
When you run tests locally, run both your newly added test file and the whole `shardformer` module test suite.
From 7ff11b5537123b50d8b1b3b0fbaca0fa31d9481b Mon Sep 17 00:00:00 2001
From: binmakeswell
Date: Mon, 17 Jul 2023 21:07:44 +0800
Subject: [PATCH 2/6] [example] add llama pretraining (#4257)
---
README.md | 11 +++++++++++
docs/README-zh-Hans.md | 10 ++++++++++
examples/language/llama/README.md | 11 +++++++++++
3 files changed, 32 insertions(+)
create mode 100644 examples/language/llama/README.md
diff --git a/README.md b/README.md
index 34c8a6b730a3..21670e1e59fb 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
## Latest News
+* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
@@ -49,6 +50,7 @@
Parallel Training Demo
+ - LLaMA
- GPT-3
- GPT-2
- BERT
@@ -216,6 +218,15 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
## Parallel Training Demo
+### LLaMA
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
+
### GPT-3
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 1dde7a816676..e229c65d890c 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## News
+* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
@@ -49,6 +50,7 @@
-
Parallel Training Demo
+ - LLaMA
- GPT-3
- GPT-2
- BERT
@@ -209,6 +211,14 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
(Back to top)
## Parallel Training Demo
+### LLaMA
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
### GPT-3
diff --git a/examples/language/llama/README.md b/examples/language/llama/README.md
new file mode 100644
index 000000000000..871804f2ca86
--- /dev/null
+++ b/examples/language/llama/README.md
@@ -0,0 +1,11 @@
+# Pretraining LLaMA: best practices for building LLaMA-like base models
+
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
+
+> Since the main branch is being actively updated, this example is temporarily kept as an [independent branch](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) to keep the code stable.
From 4b977541a86c90946badc77a6a77fee64fdc8cce Mon Sep 17 00:00:00 2001
From: Cuiqing Li
Date: Tue, 18 Jul 2023 23:53:38 +0800
Subject: [PATCH 3/6] [Kernels] added Triton implementation of self-attention
 for colossal-ai (#4241)
* added softmax kernel
* added qkv_kernel
* added ops
* adding tests
* upload tests
* fix tests
* debugging
* debugging tests
* debugging
* added
* fixed errors
* added softmax kernel
* clean codes
* added tests
* update tests
* update tests
* added attention
* add
* fixed pytest checking
* add cuda check
* fix cuda version
* fix typo
---
colossalai/kernel/triton/ops.py | 209 ++++++++++++++++++
colossalai/kernel/triton/qkv_matmul_kernel.py | 109 +++++++++
colossalai/kernel/triton/softmax_kernel.py | 44 ++++
tests/test_kernels/test_self_attention.py | 136 ++++++++++++
tests/test_kernels/test_softmax.py | 27 +++
5 files changed, 525 insertions(+)
create mode 100644 colossalai/kernel/triton/ops.py
create mode 100644 colossalai/kernel/triton/qkv_matmul_kernel.py
create mode 100644 colossalai/kernel/triton/softmax_kernel.py
create mode 100644 tests/test_kernels/test_self_attention.py
create mode 100644 tests/test_kernels/test_softmax.py
diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py
new file mode 100644
index 000000000000..5e8d4ba3ec99
--- /dev/null
+++ b/colossalai/kernel/triton/ops.py
@@ -0,0 +1,209 @@
+import torch
+from torch import nn
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ from .qkv_matmul_kernel import qkv_gemm_4d_kernel
+ from .softmax_kernel import softmax_kernel
+
+ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
+ r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+ Args:
+ q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
+ k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
+ v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
+ input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_len, seq_len)
+ scale: the float scale value which is used to multiply with Q*K^T before doing softmax
+
+ Return:
+ output (torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size)
+ """
+ assert len(q.shape) == 4, "q must be a 4-dimensional tensor"
+ batches, M, H, K = q.shape
+ assert q.shape == k.shape, "the shape of q and the shape of k must be equal"
+ assert q.shape == v.shape, "the shape of q and the shape of v must be equal"
+ assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal"
+
+ N = k.shape[1]
+
+ # head_size * num_of_head
+ d_model = q.shape[-1] * q.shape[-2]
+
+ score_output = torch.empty(
+ (batches, H, M, N), device=q.device, dtype=q.dtype)
+
+ grid = lambda meta: (
+ batches,
+ H,
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
+ triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ )
+
+ qkv_gemm_4d_kernel[grid](
+ q, k, score_output,
+ M, N, K,
+ q.stride(0), q.stride(2), q.stride(1), q.stride(3),
+ k.stride(0), k.stride(2), k.stride(3), k.stride(1),
+ score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+ scale=scale,
+ # currently manually setting, later on we can use auto-tune config to match best setting
+ BLOCK_SIZE_M=64,
+ BLOCK_SIZE_N=32,
+ BLOCK_SIZE_K=32,
+ GROUP_SIZE_M=8,
+ )
+
+ softmax_output = torch.empty(
+ score_output.shape, device=score_output.device, dtype=score_output.dtype)
+ score_output_shape = score_output.shape
+
+ score_output = score_output.view(-1, score_output.shape[-1])
+ n_rows, n_cols = score_output.shape
+
+ if n_rows <= 350000:
+
+ block_size = max(triton.next_power_of_2(n_cols), 2)
+ num_warps = 4
+ if block_size >= 4096:
+ num_warps = 16
+ elif block_size >= 2048:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ softmax_kernel[(n_rows, )](
+ softmax_output,
+ score_output,
+ score_output.stride(0),
+ n_cols,
+ mask_ptr = input_mask,
+ num_warps=num_warps,
+ BLOCK_SIZE=block_size,
+ )
+
+ else:
+ # TODO: adapt the softmax kernel so it also handles inputs with a very large number of rows
+ softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
+ softmax_output = softmax_output.view(*score_output_shape)
+
+ batches, H, M, K = softmax_output.shape
+ N = v.shape[-1]
+
+ output = torch.empty(
+ (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
+
+ grid = lambda meta: (
+ batches,
+ H,
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
+ triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ )
+
+ qkv_gemm_4d_kernel[grid](
+ softmax_output, v, output,
+ M, N, K,
+ softmax_output.stride(0),
+ softmax_output.stride(1),
+ softmax_output.stride(2),
+ softmax_output.stride(3),
+ v.stride(0),
+ v.stride(2),
+ v.stride(1),
+ v.stride(3),
+ output.stride(0),
+ output.stride(2),
+ output.stride(1),
+ output.stride(3),
+ BLOCK_SIZE_M=128,
+ BLOCK_SIZE_N=64,
+ BLOCK_SIZE_K=64,
+ GROUP_SIZE_M=8,
+ scale=-1,
+ )
+ return output.view(batches, -1, d_model)
+
+
+ def self_attention_compute_using_triton(qkv,
+ input_mask,
+ layer_past,
+ alibi,
+ scale,
+ head_size,
+ triangular=False,
+ use_flash=False):
+
+ assert qkv.is_contiguous()
+ assert alibi is None, "current triton self-attention does not support alibi"
+ batches = qkv.shape[0]
+ d_model = qkv.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ q = qkv[:, :, :d_model]
+ k = qkv[:, :, d_model:d_model * 2]
+ v = qkv[:, :, d_model * 2:]
+ q = q.view(batches, -1, num_of_heads, head_size)
+ k = k.view(batches, -1, num_of_heads, head_size)
+ v = v.view(batches, -1, num_of_heads, head_size)
+
+ data_output_triton = self_attention_forward_without_fusion(
+ q, k, v, input_mask, scale)
+
+ return data_output_triton
+
+
+ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
+ if mask is not None:
+ assert input.shape[-1] == mask.shape[-1], "the last dimensions of input and mask should be the same"
+ assert dim == -1 or dim == len(input.shape)-1, "currently the softmax layer only supports the last dimension"
+
+ hidden_dim = input.shape[-1]
+ output = torch.empty_like(input)
+ input = input.view(-1, hidden_dim)
+ if mask is not None:
+ mask = mask.view(-1, hidden_dim)
+ assert input.shape[0] == mask.shape[0], "the first dimension of mask and input should be the same"
+
+ num_rows, num_cols = input.shape
+ block_size = max(triton.next_power_of_2(num_cols), 2)
+ num_warps = 16
+ if block_size >= 4096:
+ num_warps = 16
+ elif block_size >= 2048:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ if num_rows <= 350000:
+ grid = (num_rows,)
+ softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
+ else:
+ # softmax_kernel launches one program per row and does not scale to a very
+ # large number of rows, so fall back to the native PyTorch softmax here
+ if mask is not None:
+ input = input + mask
+ output = torch.nn.functional.softmax(input, dim=-1).view(output.shape)
+
+ return output
\ No newline at end of file
diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py
new file mode 100644
index 000000000000..62fc6bba0360
--- /dev/null
+++ b/colossalai/kernel/triton/qkv_matmul_kernel.py
@@ -0,0 +1,109 @@
+import torch
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+
+if HAS_TRITON:
+ '''
+ this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+ '''
+ @triton.jit
+ def qkv_gemm_4d_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ M,
+ N,
+ K,
+ stride_ab,
+ stride_ah,
+ stride_am,
+ stride_ak,
+ stride_bb,
+ stride_bh,
+ stride_bk,
+ stride_bn,
+ stride_cb,
+ stride_ch,
+ stride_cm,
+ stride_cn,
+ scale,
+ # Meta-parameters
+ BLOCK_SIZE_M : tl.constexpr = 64,
+ BLOCK_SIZE_N : tl.constexpr = 32,
+ BLOCK_SIZE_K : tl.constexpr = 32,
+ GROUP_SIZE_M : tl.constexpr = 8,
+ ):
+ r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
+ where score_matrix is softmax(Q*V^T/sqrt(hidden_size))
+ Args:
+ a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
+ b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K)
+ c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N)
+ stride_ab(tl.constexpr): stride for bs-dimension for tensor array A
+ stride_ah(tl.constexpr): stride for h-dimension for tensor array A
+ stride_am(tl.constexpr): stride for m-dimension for tensor array A
+ stride_ak(tl.constexpr): stride for k-dimension for tensor array A
+ stride_bb(tl.constexpr): stride for bs-dimension for tensor array B
+ stride_bh(tl.constexpr): stride for h-dimension for tensor array B
+ stride_bk(tl.constexpr): stride for k-dimension for tensor array B
+ stride_bn(tl.constexpr): stride for n-dimension for tensor array B
+ stride_cb(tl.constexpr): stride for bs-dimension for tensor array output
+ stride_ch(tl.constexpr): stride for h-dimension for tensor array output
+ stride_cm(tl.constexpr): stride for m-dimension for tensor array output
+ stride_cn(tl.constexpr): stride for n-dimension for tensor array output
+ BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a
+ BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b
+ BLOCK_SIZE_K : tiling size for K-dimension of a and b
+ GROUP_SIZE_M : group size used to reduce cache misses (see the grouped ordering trick in the Triton matmul tutorial linked below)
+ """
+
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ batch = tl.program_id(axis = 0)
+ head = tl.program_id(axis = 1)
+ pid = tl.program_id(axis = 2)
+
+ # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
+ (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
+ b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, K, BLOCK_SIZE_K):
+ a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
+ b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
+ a = tl.load(a_ptrs, mask=a_mask, other=0.)
+ b = tl.load(b_ptrs, mask=b_mask, other=0.)
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+
+ accumulator = accumulator.to(c_ptr.dtype.element_ty)
+ if scale > 0:
+ accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
+
+
+ offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
+ stride_cn * offs_accumu_n[None, :])
+ accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
+ tl.store(c_ptrs, accumulator, mask=accumulator_mask)
diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py
new file mode 100644
index 000000000000..c215890badff
--- /dev/null
+++ b/colossalai/kernel/triton/softmax_kernel.py
@@ -0,0 +1,44 @@
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ '''
+ softmax kernel is modified based on
+ https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
+ '''
+ @triton.jit
+ def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
+ r""" the kernel function for implementing softmax operator
+ Args:
+ output_ptr: the output after finishing softmax operation, (N, hidden_dim)
+ input_ptr: the tensor of input, shape should be (N, hidden_dim)
+ n_cols(tl.constexpr): the number of cols of input
+ BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
+ """
+ row_idx = tl.program_id(0)
+ row_start_ptr = input_ptr + row_idx * row_stride
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ input_ptrs = row_start_ptr + col_offsets
+ row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+ row_minus_max = row - tl.max(row, axis=0)
+
+ if mask_ptr is not None:
+ # load mask into SRAM
+ mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
+ mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
+
+ # update
+ row_minus_max = row_minus_max + mask
+
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+ output_row_start_ptr = output_ptr + row_idx * row_stride
+ output_ptrs = output_row_start_ptr + col_offsets
+ # Write back output to DRAM
+ tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
\ No newline at end of file
diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py
new file mode 100644
index 000000000000..b316404a58db
--- /dev/null
+++ b/tests/test_kernels/test_self_attention.py
@@ -0,0 +1,136 @@
+import pytest
+from packaging import version
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from colossalai.kernel.triton.ops import self_attention_compute_using_triton
+from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = torch.cuda.is_available() and version.parse(torch.version.cuda) > version.parse('11.4')
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+def test_qkv_matmul():
+ qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
+ scale = 1.2
+ head_size = 32
+ batches = qkv.shape[0]
+ d_model = qkv.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ q = qkv[:, :, :d_model]
+ k = qkv[:, :, d_model:d_model * 2]
+
+ q = q.view(batches, -1, num_of_heads, head_size)
+ k = k.view(batches, -1, num_of_heads, head_size)
+ q_copy = q.clone()
+ k_copy = k.clone()
+ q = torch.transpose(q, 1, 2).contiguous()
+ k = torch.transpose(k, 1, 2).contiguous()
+ k = torch.transpose(k, 2, 3).contiguous()
+
+ torch_output = torch.einsum('bnij,bnjk->bnik', q, k)
+ torch_output *= scale
+
+ q, k = q_copy, k_copy
+ batches, M, H, K = q.shape
+ N = k.shape[1]
+ score_output = torch.empty(
+ (batches, H, M, N), device=q.device, dtype=q.dtype)
+
+ grid = lambda meta: (
+ batches,
+ H,
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
+ triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ )
+
+ K = q.shape[3]
+ qkv_gemm_4d_kernel[grid](
+ q, k, score_output,
+ M, N, K,
+ q.stride(0), q.stride(2), q.stride(1), q.stride(3),
+ k.stride(0), k.stride(2), k.stride(3), k.stride(1),
+ score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+ scale=scale,
+ # currently manually setting, later on we can use auto-tune config to match best setting
+ BLOCK_SIZE_M=64,
+ BLOCK_SIZE_N=32,
+ BLOCK_SIZE_K=32,
+ GROUP_SIZE_M=8,
+ )
+
+ check = torch.allclose(torch_output.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
+ assert check is True, "the outputs of triton and torch do not match"
+
+
+def self_attention_compute_using_torch(qkv,
+ input_mask,
+ scale,
+ head_size
+ ):
+
+ batches = qkv.shape[0]
+ d_model = qkv.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ q = qkv[:, :, :d_model]
+ k = qkv[:, :, d_model:d_model * 2]
+ v = qkv[:, :, d_model * 2:]
+ q = q.view(batches, -1, num_of_heads, head_size)
+ k = k.view(batches, -1, num_of_heads, head_size)
+ v = v.view(batches, -1, num_of_heads, head_size)
+
+ q = torch.transpose(q, 1, 2).contiguous()
+ k = torch.transpose(k, 1, 2).contiguous()
+ v = torch.transpose(v, 1, 2).contiguous()
+
+ k = torch.transpose(k, -1, -2).contiguous()
+
+ score_output = torch.einsum('bnij,bnjk->bnik', q, k)
+ score_output *= scale
+
+ softmax_output = F.softmax(score_output, dim = -1)
+ res = torch.einsum('bnij,bnjk->bnik', softmax_output, v)
+ res = torch.transpose(res, 1, 2)
+ res = res.contiguous()
+
+
+ return res.view(batches, -1, d_model), score_output, softmax_output
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+def test_self_attention():
+
+ qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
+ data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
+ qkv.clone(),
+ input_mask = None,
+ scale = 1.2,
+ head_size = 32
+ )
+
+ data_output_triton = self_attention_compute_using_triton(
+ qkv.clone(),
+ alibi=None,
+ head_size=32,
+ scale=1.2,
+ input_mask=None,
+ layer_past=None,
+ use_flash=False,
+ triangular=True)
+
+ check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
+ assert check is True, "the triton output is not matched with torch output"
+
+
+if __name__ == "__main__":
+ test_qkv_matmul()
+ test_self_attention()
\ No newline at end of file
diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py
new file mode 100644
index 000000000000..843d811d019c
--- /dev/null
+++ b/tests/test_kernels/test_softmax.py
@@ -0,0 +1,27 @@
+import pytest
+from packaging import version
+import torch
+from torch import nn
+
+from colossalai.kernel.triton.ops import softmax
+
+TRITON_CUDA_SUPPORT = torch.cuda.is_available() and version.parse(torch.version.cuda) > version.parse('11.4')
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+def test_softmax_op():
+ data_samples = [
+ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32),
+ torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32),
+ torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16)
+ ]
+
+ for data in data_samples:
+ module = nn.Softmax(dim = -1)
+ data_torch_out = module(data)
+ data_triton_out = softmax(data)
+ check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
+ assert check is True, "softmax outputs from triton and torch are not matched"
+
+
+if __name__ == "__main__":
+ test_softmax_op()
\ No newline at end of file
From fc5cef2c79265e36b585ef22c5e1d7f18be52a4e Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 19 Jul 2023 16:43:01 +0800
Subject: [PATCH 4/6] [lazy] support init on cuda (#4269)
* [lazy] support init on cuda
* [test] update lazy init test
* [test] fix transformer version
---
colossalai/lazy/lazy_init.py | 28 ++++++++++++++++++++--------
requirements/requirements-test.txt | 2 +-
tests/test_lazy/lazy_init_utils.py | 10 +++++++---
tests/test_lazy/test_models.py | 5 +++--
4 files changed, 31 insertions(+), 14 deletions(-)
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index 8b911407307c..1f5345015bf2 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
from types import MethodType
from typing import Callable, Dict, Optional, Union
@@ -61,12 +62,15 @@ class _MyTensor(Tensor):
"""
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+ default_device: Optional[torch.device] = None
+
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
cls._pre_op_fn()
if concrete_data is not None:
# uniform api as LazyTensor
data = concrete_data
else:
+ kwargs['device'] = cls.default_device
data = func(*args, **kwargs)
return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
@@ -142,6 +146,8 @@ class LazyTensor(torch.Tensor):
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+ default_device: Optional[torch.device] = None
+
@staticmethod
def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
if concrete_data is not None:
@@ -159,6 +165,8 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
return r
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
+ if func.__name__ in _NORMAL_FACTORY:
+ kwargs = {**kwargs, 'device': LazyTensor.default_device}
self._factory_method = (func, args, kwargs) # (func, args, kwargs)
self._op_buffer = [] # (func, args, kwargs, replace)
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
@@ -206,16 +214,11 @@ def _materialize_data(self) -> torch.Tensor:
if self._materialized_data is None:
# apply factory method
func, args, kwargs = self._factory_method
-
# apply cached sequence
self._pre_op_fn()
- try:
- init_val = func(*tree_map(self._replace_with_materialized, args),
- **tree_map(self._replace_with_materialized, kwargs))
- except TypeError as e:
- print(f'init fn: {func.__name__}')
- raise e
+ init_val = func(*tree_map(self._replace_with_materialized, args),
+ **tree_map(self._replace_with_materialized, kwargs))
self._materialized_data = self._rerun_ops(init_val)
return self._materialized_data
@@ -305,6 +308,7 @@ def wrap(y, i=None):
else:
# out of place op, create new lazy tensor
fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
+ fn.__name__ = func.__name__
lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
return lazy_y
elif type(y) is Tensor:
@@ -435,14 +439,21 @@ class LazyInitContext:
"""
_replaced: bool = False
- def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor):
+ def __init__(self,
+ tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
+ default_device: Optional[Union[torch.device, str, int]] = None):
+ assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self.overrides = {}
self.tensor_cls = tensor_cls
+ self.old_default_device = LazyTensor.default_device
+ self.default_device = default_device
def __enter__(self):
if LazyInitContext._replaced:
raise RuntimeError(f'LazyInitContext is not reentrant')
LazyInitContext._replaced = True
+ self.old_default_device = self.tensor_cls.default_device
+ self.tensor_cls.default_device = self.default_device
def wrap_factory_method(target):
# factory functions (eg. torch.empty())
@@ -518,6 +529,7 @@ def wrapper(*args, **kwargs):
setattr(torch, name, wrapper)
def __exit__(self, exc_type, exc_val, exc_tb):
+ self.tensor_cls.default_device = self.old_default_device
LazyInitContext._replaced = False
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, orig)
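A usage sketch for the new `default_device` argument; the import path (`colossalai.lazy`) and the per-tensor `materialize()` call are assumptions based on this diff and the class docstrings:

```python
import torch.nn as nn

from colossalai.lazy import LazyInitContext

# with default_device='cuda', the factory calls (torch.empty, etc.) recorded
# inside the context target CUDA when the tensors are materialized
with LazyInitContext(default_device='cuda'):
    model = nn.Linear(1024, 1024)

for param in model.parameters():
    param.materialize()                 # assumption: LazyTensor.materialize()
    assert param.device.type == 'cuda'  # parameters should now live on CUDA
```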
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 50121a9283f2..9f6580c72d1b 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -4,7 +4,7 @@ pytest
coverage==7.2.3
git+https://github.com/hpcaitech/pytest-testmon
torchvision
-transformers
+transformers==4.30.2
timm
titans
torchaudio
diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py
index 73c3c5422d8a..9d9e9a3a5c76 100644
--- a/tests/test_lazy/lazy_init_utils.py
+++ b/tests/test_lazy/lazy_init_utils.py
@@ -61,14 +61,18 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn:
f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}'
-def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
+def check_lazy_init(entry: TestingEntry,
+ seed: int = 42,
+ verbose: bool = False,
+ check_forward: bool = False,
+ default_device: str = 'cpu') -> None:
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
_MyTensor._pre_op_fn = lambda *args: set_seed(seed)
LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
- ctx = LazyInitContext(tensor_cls=_MyTensor)
+ ctx = LazyInitContext(tensor_cls=_MyTensor, default_device=default_device)
with ctx:
model = model_fn()
- ctx = LazyInitContext()
+ ctx = LazyInitContext(default_device=default_device)
with ctx:
deferred_model = model_fn()
copied_deferred_model = deepcopy(deferred_model)
diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py
index 4b7aeed73a69..e37184125d21 100644
--- a/tests/test_lazy/test_models.py
+++ b/tests/test_lazy/test_models.py
@@ -6,13 +6,14 @@
@pytest.mark.skipif(not SUPPORT_LAZY, reason='requires torch >= 1.12.0')
@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
-def test_torchvision_models_lazy_init(subset):
+@pytest.mark.parametrize('default_device', ['cpu', 'cuda'])
+def test_torchvision_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
continue
- check_lazy_init(entry, verbose=True)
+ check_lazy_init(entry, verbose=True, default_device=default_device)
if __name__ == '__main__':
From c6f6005990b182d7ee34c1fb84762d31ce7d3616 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 21 Jul 2023 14:39:01 +0800
Subject: [PATCH 5/6] [checkpointio] Sharded Optimizer Checkpoint for Gemini
Plugin (#4302)
* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gathered is true under GeminiPlugin
---
colossalai/booster/plugin/gemini_plugin.py | 131 +++++++++++++---
.../checkpoint_io/general_checkpoint_io.py | 38 +++--
colossalai/checkpoint_io/utils.py | 38 +++++
colossalai/zero/gemini/gemini_optimizer.py | 140 ++++++++++++++----
docs/source/en/basics/booster_api.md | 5 +-
docs/source/en/basics/booster_checkpoint.md | 2 -
docs/source/en/basics/booster_plugins.md | 2 -
docs/source/zh-Hans/basics/booster_api.md | 5 +-
.../zh-Hans/basics/booster_checkpoint.md | 1 -
docs/source/zh-Hans/basics/booster_plugins.md | 1 -
.../test_gemini_checkpoint_io.py | 4 +-
.../test_gemini_torch_compability.py | 6 +-
12 files changed, 289 insertions(+), 84 deletions(-)
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 6191f271c318..7b6e17337d36 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -1,3 +1,4 @@
+import gc
import logging
import os
import warnings
@@ -12,11 +13,19 @@
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
+from colossalai.checkpoint_io.utils import (
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ get_shard_filename,
+ load_shard_state_dict,
+ save_state_dict,
+ save_state_dict_shards,
+)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini import ZeroOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
from .dp_plugin_base import DPPluginBase
@@ -37,7 +46,7 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
"""
Save unsharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
- As there is communication when getting state dict, this must be called on all processes.
+ As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
@@ -54,7 +63,7 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
- As there is communication when getting state dict, this must be called on all processes.
+ As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
state_dict = optimizer.state_dict()
@@ -76,7 +85,8 @@ def save_sharded_model(self,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
- Save sharded model
+ Save sharded model.
+ As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
@@ -86,28 +96,24 @@ def save_sharded_model(self,
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
- total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
- for idx, shard_pair in enumerate(state_dict_shard):
- if not self.coordinator.is_master():
- continue
- shard = shard_pair[0]
- shard_file = get_shard_filename(weights_name, idx)
- total_size = total_size + shard_pair[1]
- for key in shard.keys():
- index_file.append_weight_map(key, shard_file)
-
- checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
- save_state_dict(shard, checkpoint_file_path, use_safetensors)
- index_file.append_meta_data("total_size", total_size)
+ # Save shards of model parameters.
+ is_master = self.coordinator.is_master()
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=is_master,
+ use_safetensors=use_safetensors)
# only save the index file on the master rank
if self.coordinator.is_master():
+ index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}.")
def load_sharded_model(self,
model: GeminiDDP,
@@ -115,7 +121,7 @@ def load_sharded_model(self,
strict: bool = False,
use_safetensors: bool = False):
"""
- load shard model, load model from multiple files
+        Load sharded model, i.e. load the model from multiple files.
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
@@ -125,16 +131,93 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
+
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = optimizer.unwrap()
+
+ assert isinstance(optimizer, ZeroOptimizer)
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
Path(checkpoint).mkdir(parents=True, exist_ok=True)
- super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+ # Preparing file paths and index file.
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+
+ # Store the information of param groups to param_group_file.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = optimizer.get_param_groups_for_saving()
+ torch.save(param_groups, group_file_path)
+
+ # States are broken into shards within max_shard_size.
+ state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+
+ # Save shards of optimizer states.
+ is_master = self.coordinator.is_master()
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=is_master,
+ use_safetensors=False)
+
+ # Wrap up index file. Only save it on master rank.
+ if self.coordinator.is_master():
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ logging.info(f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}.")
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
"""
- # TODO(Baizhou): To be implemented.
- pass
+
+ if not os.path.isfile(checkpoint_index_file):
+ logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = optimizer.unwrap()
+
+ assert isinstance(optimizer, ZeroOptimizer)
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+
+ # Load param_groups.
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory.')
+ saved_param_groups = torch.load(param_group_path)
+ optimizer.load_param_groups(saved_param_groups)
+
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ # Load optimizer states from shard files under checkpoint path.
+ # For each file, only load the states managed by current process.
+ for shard_file in checkpoint_files:
+ state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ optimizer.load_param_states(state_dict_shard)
+ del state_dict_shard
+ gc.collect()
+
+ optimizer.optimizer_loading_epilogue()
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+        Save lr scheduler to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
class GeminiModel(ModelWrapper):
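With this in place, a Gemini run can save and load sharded optimizer states through the booster checkpoint API. A minimal sketch (assumes a `torchrun` launch and that `load_optimizer` resolves the index file inside the directory):

```python
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

booster = Booster(plugin=GeminiPlugin())
model = nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# all ranks must participate: collecting Gemini optimizer states is collective
booster.save_optimizer(optimizer, 'optim_ckpt_dir', shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, 'optim_ckpt_dir')
```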
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index e1d9066948dd..83e4bdcc863b 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -5,6 +5,7 @@
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
+import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
@@ -16,7 +17,6 @@
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
- has_index_file,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
@@ -25,6 +25,7 @@
load_states_into_optimizer,
save_param_groups,
save_state_dict,
+ save_state_dict_shards,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
@@ -122,15 +123,13 @@ def save_sharded_optimizer(
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
- total_size = 0
- for idx, shard_pair in enumerate(sharded_state):
- shard, current_size = shard_pair
- shard_file = get_shard_filename(states_name, idx)
- total_size = total_size + current_size
- for key in shard.keys():
- index_file.append_weight_map(key, shard_file)
- checkpoint_file_path = os.path.join(checkpoint, shard_file)
- save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+ # In general cases, is_master is set to True to get the right behavior.
+ total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=True,
+ use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@@ -172,18 +171,17 @@ def save_sharded_model(self,
# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
-
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
- total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
- for idx, shard_pair in enumerate(state_dict_shard):
- shard = shard_pair[0]
- shard_file = get_shard_filename(weights_name, idx)
- total_size = total_size + shard_pair[1]
- for key in shard.keys():
- index_file.append_weight_map(key, shard_file)
- checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
- save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+    # Save shards of model parameters.
+ # In general cases, is_master is set to True to get the right behavior.
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=True,
+ use_safetensors=use_safetensors)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 19e28c3f7068..8837776aee4d 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,4 +1,5 @@
# coding=utf-8
+import os
import re
from collections import abc as container_abcs
from collections import defaultdict
@@ -103,6 +104,43 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim
+def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+ checkpoint: str,
+ index_file: "CheckpointIndexFile",
+ base_filename: str,
+ is_master: bool,
+ use_safetensors: bool = False) -> int:
+ '''
+    Save sharded state dict only on master rank; this method can be used by both model and optimizer states.
+ Args:
+ sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
+ checkpoint (str): The path of checkpoint directory as string.
+ index_file (CheckpointIndexFile): The index file object to be updated.
+ base_filename (str): Decides the prefix of filenames of shards.
+ is_master (bool): Whether current rank is master.
+ use_safetensors (bool): Whether to use safetensors to save checkpoint.
+
+ Returns:
+ int: the total size of shards
+ '''
+
+ total_size = 0
+ for idx, shard_pair in enumerate(sharded_state_dict):
+ if not is_master:
+ continue
+ shard, current_size = shard_pair
+ shard_file = get_shard_filename(base_filename, idx)
+ total_size = total_size + current_size
+ for key in shard.keys():
+ index_file.append_weight_map(key, shard_file)
+ checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+ # Only save on master rank.
+ save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+
+ return total_size
+
+
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 99aff6f1c527..7d0db6b1fa23 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -3,7 +3,7 @@
import gc
import math
import warnings
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
import torch
import torch.distributed as dist
@@ -11,8 +11,10 @@
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
+from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
+from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
@@ -360,10 +362,12 @@ def get_offsets(self, param_id: int) -> tuple:
begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
chunk_offset = begin_in_chunk
- shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
+ if chunk.keep_gathered:
+ shard_offset = 0
+ else:
+ shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
shard_size = end_in_chunk - begin_in_chunk
assert chunk_offset >= 0 and shard_offset >= 0
-
return chunk_offset, shard_offset, shard_size
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
@@ -427,7 +431,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
dtype=torch.float32,
requires_grad=False).cpu()
else:
- collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
+ state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
+ collected_states[state_name] = torch.reshape(state_tensor, param.shape)
return collected_states
# Check whether the param with given id is managed by current process.
@@ -536,6 +541,31 @@ def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_s
target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
next_state_offset += shard_size
+ def get_param_groups_for_saving(self) -> list:
+ '''
+ Return the param_groups in Pytorch format when saving to checkpoint.
+ '''
+
+ param_groups = copy.deepcopy(self.param_groups_backup)
+
+ # To be compatible with pytorch checkpointing,
+ # store extra hyperparameters used by pytorch Adam optimizer.
+ torch_special_hyperparameters = {
+ 'amsgrad': False,
+ 'maximize': False,
+ 'foreach': None,
+ 'capturable': False,
+ 'differentiable': False,
+ 'fused': False
+ }
+
+ for group in param_groups:
+ for k, v in torch_special_hyperparameters.items():
+ if k not in group:
+ group[k] = v
+
+ return param_groups
+
def state_dict(self, only_rank_0: bool = True) -> dict:
"""
Args:
@@ -555,21 +585,7 @@ def state_dict(self, only_rank_0: bool = True) -> dict:
so it should be called only when memory resources are abundant.
"""
state_dict = {}
- state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
-
- torch_special_hyperparameters = {
- 'amsgrad': False,
- 'maximize': False,
- 'foreach': None,
- 'capturable': False,
- 'differentiable': False,
- 'fused': False
- }
-
- for group in state_dict['param_groups']:
- for k, v in torch_special_hyperparameters.items():
- if k not in group:
- group[k] = v
+ state_dict['param_groups'] = self.get_param_groups_for_saving()
# Collect optimizer states.
state_dict['state'] = dict()
@@ -634,8 +650,24 @@ def cast(param, state_range, value, key=None):
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
+ def load_param_states(self, param_states: dict):
+ """Loads param states from a state_dict. The param_states can be complete or sharded.
+ During loading, filter out the part of states not considered by current process.
+
+ Args:
+ param_states (dict): A mapping from param_id to its states.
+ """
+ for param_id, states in param_states.items():
+ if param_id in self.id_to_fake_params:
+ self.load_single_param_states(param_id, states)
+
+ def optimizer_loading_epilogue(self):
+ # Epilogue when loading state_dict to pytorch optimizer.
+ self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
+ self.optim.defaults.setdefault('differentiable', False)
+
def load_state_dict(self, state_dict: dict):
- """Loads optimizer state from whole optimizer state_dict.
+ """Loads optimizer state from complete optimizer state_dict.
During loading, filter out the part of states not considered by current process.
Args:
@@ -643,17 +675,71 @@ def load_state_dict(self, state_dict: dict):
from a call to :meth:`state_dict`.
"""
assert 'param_groups' in state_dict
+ assert 'state' in state_dict
self.load_param_groups(state_dict['param_groups'])
+ self.load_param_states(state_dict['state'])
+ self.optimizer_loading_epilogue()
- state = state_dict['state']
+ def state_shard(self,
+ prefix: str = '',
+ max_shard_size: int = 1024,
+ only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
+ """Returns dictionaries containing shards of optimizer states one by one.
+ The max size of each dictionary shard is specified by ``max_shard_size``.
- for param_id, param_states in state.items():
- if param_id in self.id_to_fake_params:
- self.load_single_param_states(param_id, param_states)
+ Args:
+ prefix (str, optional): the prefix for states. Defaults to ''.
+ max_shard_size (int, optional): max size of a state dict shard (in MB). Defaults to 1024.
+ only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
+ only on rank 0. Defaults to True.
- # Epilogue for pytorch optimizer.
- self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
- self.optim.defaults.setdefault('differentiable', False)
+ Yields:
+ Iterator[Tuple[OrderedDict, int]]: A generator of state dict shards of optimizer states, each paired with its size in MB.
+ """
+
+ current_block = {}
+ current_block_size = 0
+
+ for param_id in self.id_to_real_params:
+
+ dist.barrier()
+ state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
+
+ ret_block = None
+ ret_block_size = 0
+
+ # A state might contain more than one tensor.
+ # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+ state_size = 0
+ isDTensor = False
+ for state_tensor in state.values():
+
+ # When state_tensor is not a Tensor, e.g. an SGD optimizer with
+ # momentum set to 0 can have None as a state, the size calculation
+ # should be skipped to avoid errors.
+ if not isinstance(state_tensor, torch.Tensor):
+ continue
+
+ # If the states are stored as DTensors, mark isDTensor as true.
+ if is_distributed_tensor(state_tensor):
+ isDTensor = True
+ state_size += calculate_tensor_size(state_tensor)
+
+ if not isDTensor:
+
+ if current_block_size + state_size > max_shard_size and current_block_size > 0:
+ ret_block = current_block
+ ret_block_size = current_block_size
+ current_block = {}
+ current_block_size = 0
+
+ current_block[param_id] = state
+ current_block_size += state_size
+
+ if ret_block is not None:
+ yield ret_block, ret_block_size
+
+ yield current_block, current_block_size
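A hedged usage sketch for this generator, assuming `optimizer` is a boosted Zero optimizer in an initialized process group (the file naming and rank-0 gating are illustrative, not the checkpoint format used elsewhere in the codebase):

```python
import torch
import torch.distributed as dist

# Every rank must drive the generator, since collect_states() performs
# collective communication; with only_rank_0=True, only rank 0 holds the
# gathered states and writes them out.
for idx, (shard, shard_size) in enumerate(optimizer.state_shard(max_shard_size=1024)):
    if dist.get_rank() == 0:
        torch.save(shard, f'optim_states.{idx:05d}.bin')    # hypothetical file naming
        print(f'shard {idx}: ~{shard_size} MB, {len(shard)} params')
```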
class GeminiAdamOptimizer(ZeroOptimizer):
diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md
index 22d5ee818019..1e75c343c14f 100644
--- a/docs/source/en/basics/booster_api.md
+++ b/docs/source/en/basics/booster_api.md
@@ -21,10 +21,13 @@ Plugin is an important component that manages parallel configuration (eg: The ge
**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.
-**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution, it implements data parallelism at the module level which can run across multiple machines.
+**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of PyTorch. It implements data parallelism at the module level, which can run across multiple machines.
**_LowLevelZeroPlugin:_** This plugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across data-parallel workers/GPUs. Stage 2: shards optimizer states + gradients across data-parallel workers/GPUs.
+
+**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of PyTorch and can be used to train models with zero-redundancy data parallelism (ZeRO-DP).
+
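A minimal boosting sketch under these plugins (the toy model, hyperparameters, and launch config are placeholders; run under a distributed launcher such as torchrun):

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

plugin = TorchDDPPlugin()    # or GeminiPlugin / LowLevelZeroPlugin / TorchFSDPPlugin
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```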
### API of booster
{{ autodoc:colossalai.booster.Booster }}
diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md
index adc0af60b7de..b2840fe87441 100644
--- a/docs/source/en/basics/booster_checkpoint.md
+++ b/docs/source/en/basics/booster_checkpoint.md
@@ -21,8 +21,6 @@ Model must be boosted by `colossalai.booster.Booster` before loading. It will de
## Optimizer Checkpoint
-> ⚠ Saving optimizer checkpoint in a sharded way is not supported yet.
-
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
Optimizer must be boosted by `colossalai.booster.Booster` before saving.
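With sharded optimizer checkpoints now supported, saving and restoring might look like the following sketch, reusing the boosted `booster` and `optimizer` from the example above (path and shard size are illustrative):

```python
# Save the boosted optimizer in shards of at most ~1 GB each (size is in MB),
# then restore it into the same boosted optimizer.
booster.save_optimizer(optimizer, 'optimizer_ckpt', shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, 'optimizer_ckpt')
```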
diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md
index 5e2586b836ad..c5c45abce8f7 100644
--- a/docs/source/en/basics/booster_plugins.md
+++ b/docs/source/en/basics/booster_plugins.md
@@ -51,8 +51,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
-> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future.
-
### Torch DDP Plugin
More details can be found in [PyTorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md
index 1df821ce7d6e..b2235b73bca1 100644
--- a/docs/source/zh-Hans/basics/booster_api.md
+++ b/docs/source/zh-Hans/basics/booster_api.md
@@ -24,10 +24,13 @@ Booster plugins are important components that manage parallel configuration (eg: the gemini plugin encapsulates
**_GeminiPlugin:_** The GeminiPlugin wraps the Gemini acceleration solution, i.e. the ZeRO optimization scheme based on chunk memory management.
-**_TorchDDPPlugin:_** The TorchDDPPlugin wraps the DDP acceleration solution, implementing model-level data parallelism that can run across multiple machines.
+**_TorchDDPPlugin:_** The TorchDDPPlugin wraps PyTorch's DDP acceleration solution, implementing model-level data parallelism that can run across multiple machines.
**_LowLevelZeroPlugin:_** The LowLevelZeroPlugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across concurrent processes or GPUs. Stage 2: shards optimizer states and gradients across concurrent processes or GPUs.
+**_TorchFSDPPlugin:_** The TorchFSDPPlugin wraps PyTorch's FSDP acceleration solution and can be used for zero-redundancy data-parallel (ZeroDP) training.
+
+
### Booster API
diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md
index d75f18c908ba..4ed049dcf44f 100644
--- a/docs/source/zh-Hans/basics/booster_checkpoint.md
+++ b/docs/source/zh-Hans/basics/booster_checkpoint.md
@@ -21,7 +21,6 @@
## Optimizer Checkpoint
-> ⚠ Saving optimizer checkpoints in a sharded way is not supported yet.
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md
index 5bd88b679000..0f355c43901c 100644
--- a/docs/source/zh-Hans/basics/booster_plugins.md
+++ b/docs/source/zh-Hans/basics/booster_plugins.md
@@ -51,7 +51,6 @@ Zero-2 does not support local gradient accumulation. If you insist on using it, although gradients can be accumulated
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
-> ⚠ This plugin can currently only load optimizer checkpoints saved by itself with the same number of processes. This will be fixed in the future.
### Torch DDP Plugin
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 0235ff2e2c81..7b664419b405 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -52,7 +52,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
@@ -117,7 +117,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py
index b34e3e3a1310..464fccb39103 100644
--- a/tests/test_checkpoint_io/test_gemini_torch_compability.py
+++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -19,7 +19,7 @@
@clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
@@ -83,7 +83,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
@clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
def exam_gemini_load_from_torch(shard: bool, model_name: str):
@@ -165,7 +165,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
From 02192a632e6c6f965d93ec79937f97e10e121307 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 21 Jul 2023 18:36:35 +0800
Subject: [PATCH 6/6] [ci] support testmon core pkg change detection (#4305)
---
.github/workflows/build_on_pr.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 380c8e9f882c..8a1bc8e113de 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -213,6 +213,7 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
- name: Store Testmon Cache
run: |