- - LLaMA
+ - LLaMA 1/2
- GPT-3
- GPT-2
- BERT
@@ -217,8 +218,16 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
(back to top)
## Parallel Training Demo
+### LLaMA2
+
+- 70 billion parameter LLaMA2 model training accelerated by 195%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
-### LLaMA
+### LLaMA1
@@ -463,7 +472,7 @@ To cite this project, you can use the following BibTeX citation.
}
```
-Colossal-AI has been accepted as official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+Colossal-AI has been accepted as an official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
(back to top)
diff --git a/applications/Chat/README.md b/applications/Chat/README.md
index 5a1187ab503d..59e2c4548365 100644
--- a/applications/Chat/README.md
+++ b/applications/Chat/README.md
@@ -200,7 +200,6 @@ We provide an online inference server and a benchmark. We aim to run inference o
We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference.
Online inference server scripts can help you deploy your own services.
-
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
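
To give a feel for the 8-bit (RTN) path, here is a minimal sketch using the generic HuggingFace `load_in_8bit` loading flow (illustrative only: `"your-coati-checkpoint"` is a placeholder, and the maintained, benchmarked scripts live in [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference)):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Weights are quantized to int8 at load time (round-to-nearest);
# activations stay in half precision. Requires the `bitsandbytes` package.
model = AutoModelForCausalLM.from_pretrained(
    "your-coati-checkpoint",  # placeholder path
    load_in_8bit=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("your-coati-checkpoint")
```
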
## Coati7B examples
@@ -428,7 +427,7 @@ Thanks so much to all of our amazing contributors!
-- An open-source low cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
+- An open-source low-cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
@@ -469,8 +468,7 @@ Coati is developed by ColossalAI Team:
- [ofey404](https://github.com/ofey404)
- [Wenhao Chen](https://github.com/CWHer)
-The Phd student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
-
+The PhD students from the [HPC-AI Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
- [Xue Fuzhao](https://github.com/XueFuzhao)
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 636b4e6772cb..2959d3fac81c 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -19,7 +19,7 @@
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
-
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
@@ -71,6 +71,42 @@ def _preprocess(sources: Sequence[str],
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
+def _preprocess_chatglm(sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+                     ) -> Tuple[torch.Tensor, torch.Tensor, None]:
+    """
+    Preprocess the data by tokenizing.
+    The attention mask is returned as None; ChatGLM computes it from the input ids.
+    """
+
+ labels = []
+ input_ids = []
+ for source, target in zip(sources, targets):
+ source_id = tokenizer.encode(text=source, add_special_tokens=False)
+ target_id = tokenizer.encode(text=target, add_special_tokens=False)
+ input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)
+ # truncate
+ sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
+        truncate_length = max(0, len(input_id) - max_length)
+        input_id = input_id[truncate_length:]
+        if truncate_length == len(source_id) + 1:
+            input_id = sp_token_list + input_id[1:]
+        elif truncate_length > len(source_id) + 1:
+            input_id = sp_token_list + input_id[2:]
+
+ context_length = input_id.index(tokenizer.bos_token_id)
+ mask_position = context_length - 1
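+        # Layout sketch: input_id = source + [gMASK, <sop>] + target (+ pads);
+        # everything before <sop> is excluded from the loss below.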
+        label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1:]
+
+ pad_len = max_length - len(input_id)
+ input_id = input_id + [tokenizer.pad_token_id] * pad_len
+ input_ids.append(input_id)
+ labels.append(label + [IGNORE_INDEX] * pad_len)
+ return torch.tensor(input_ids), torch.tensor(labels), None
+
+
class SFTDataset(Dataset):
"""
Dataset for sft model
@@ -94,18 +130,25 @@ def __init__(self,
data["completion"] + tokenizer.eos_token
for data in tqdm(dataset, disable=not is_rank_0())
]
-
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ else:
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ if self.attention_mask is not None:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx],
+ attention_mask=self.attention_mask[idx])
+ else:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx])
class SupervisedDataset(Dataset):
@@ -137,14 +180,22 @@ def __init__(self,
]
logger.info("Tokenizing inputs... This may take some time...")
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ else:
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ if self.attention_mask is not None:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx],
+ attention_mask=self.attention_mask[idx])
+ else:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx])
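+
+
+# Usage sketch (illustrative; `raw_records` is a hypothetical list of
+# {"prompt": ..., "completion": ...} dicts): with a ChatGLMTokenizer the samples
+# simply omit `attention_mask`, and the model rebuilds it from the input ids.
+#
+#   tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b")
+#   dataset = SFTDataset(raw_records, tokenizer, max_length=512)
+#   sample = dataset[0]    # dict with "input_ids" and "labels" only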
diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py
index ff75852576c8..b4646f282f0c 100644
--- a/applications/Chat/coati/experience_maker/base.py
+++ b/applications/Chat/coati/experience_maker/base.py
@@ -10,7 +10,7 @@
@dataclass
class Experience:
"""Experience is a batch of data.
- These data should have the the sequence length and number of actions.
+    These data should share the same sequence length and number of actions.
Left padding for sequences is applied.
Shapes of each tensor:
diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py
new file mode 100644
index 000000000000..373f19553fdc
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/__init__.py
@@ -0,0 +1,3 @@
+from .chatglm_actor import ChatGLMActor
+
+__all__ = ['ChatGLMActor']
\ No newline at end of file
diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py
new file mode 100644
index 000000000000..c35d994e9319
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import torch
+from .configuration_chatglm import ChatGLMConfig
+from .modeling_chatglm import ChatGLMForConditionalGeneration
+
+from ..base import Actor
+
+
+class ChatGLMActor(Actor):
+ """
+ ChatGLM Actor model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (ChatGLMConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+
+    LoRA is not supported for now.
+ """
+
+ def __init__(self,
+                 pretrained: Optional[str] = None,
+ config: Optional[ChatGLMConfig] = None,
+ checkpoint: bool = False) -> None:
+ if pretrained is not None:
+ model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
+ elif config is not None:
+ model = ChatGLMForConditionalGeneration(config)
+ else:
+ model = ChatGLMForConditionalGeneration(ChatGLMConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank=0, lora_train_bias='none')
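+
+
+# Usage sketch (illustrative): wrap a pretrained checkpoint as the RLHF actor;
+# `checkpoint=True` enables gradient checkpointing to trade compute for memory.
+#
+#   actor = ChatGLMActor(pretrained="THUDM/chatglm-6b", checkpoint=True)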
diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
new file mode 100644
index 000000000000..f7717f7e68b6
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
@@ -0,0 +1,446 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
+"""
+"""Tokenization classes for ChatGLM."""
+from typing import List, Optional, Union
+import os
+
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging, PaddingStrategy
+from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
+from typing import Dict
+import sentencepiece as spm
+import numpy as np
+
+logger = logging.get_logger(__name__)
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "THUDM/chatglm-6b": 2048,
+}
+
+
+class TextTokenizer:
+ def __init__(self, model_path):
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.Load(model_path)
+ self.num_tokens = self.sp.vocab_size()
+
+ def encode(self, text):
+ return self.sp.EncodeAsIds(text)
+
+ def decode(self, ids: List[int]):
+ return self.sp.DecodeIds(ids)
+
+ def tokenize(self, text):
+ return self.sp.EncodeAsPieces(text)
+
+ def convert_tokens_to_string(self, tokens):
+ return self.sp.DecodePieces(tokens)
+
+ def convert_tokens_to_ids(self, tokens):
+ return [self.sp.PieceToId(token) for token in tokens]
+
+ def convert_token_to_id(self, token):
+ return self.sp.PieceToId(token)
+
+ def convert_id_to_token(self, idx):
+ return self.sp.IdToPiece(idx)
+
+ def __len__(self):
+ return self.num_tokens
+
+
+class SPTokenizer:
+ def __init__(
+ self,
+ vocab_file,
+ num_image_tokens=20000,
+ max_blank_length=80,
+ byte_fallback=True,
+ ):
+ assert vocab_file is not None
+ self.vocab_file = vocab_file
+ self.num_image_tokens = num_image_tokens
+ self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""]
+ self.max_blank_length = max_blank_length
+ self.byte_fallback = byte_fallback
+ self.text_tokenizer = TextTokenizer(vocab_file)
+
+ def _get_text_tokenizer(self):
+ return self.text_tokenizer
+
+ @staticmethod
+ def get_blank_token(length: int):
+ assert length >= 2
+ return f"<|blank_{length}|>"
+
+ @staticmethod
+ def get_tab_token():
+ return f"<|tab|>"
+
+ @property
+ def num_text_tokens(self):
+ return self.text_tokenizer.num_tokens
+
+ @property
+ def num_tokens(self):
+ return self.num_image_tokens + self.num_text_tokens
+
+ @staticmethod
+ def _encode_whitespaces(text: str, max_len: int = 80):
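+        # e.g. "a\tb  c" -> "a<|tab|>b<|blank_2|>c": tabs and runs of >= 2 spaces
+        # become single special tokens, which tokenize far more compactly.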
+ text = text.replace("\t", SPTokenizer.get_tab_token())
+ for i in range(max_len, 1, -1):
+ text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
+ return text
+
+ def _preprocess(self, text: str, linebreak=True, whitespaces=True):
+ if linebreak:
+ text = text.replace("\n", "")
+ if whitespaces:
+ text = self._encode_whitespaces(text, max_len=self.max_blank_length)
+ return text
+
+ def encode(
+ self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
+ ) -> List[int]:
+ """
+ @param text: Text to encode.
+ @param linebreak: Whether to encode newline (\n) in text.
+ @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
+ @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
+ @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
+ """
+ text = self._preprocess(text, linebreak, whitespaces)
+ if not add_dummy_prefix:
+ text = "" + text
+ tmp = self._get_text_tokenizer().encode(text)
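+        # Shift raw sentencepiece ids up by num_image_tokens: ids in
+        # [0, num_image_tokens) are reserved for <image_*> placeholders.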
+ tokens = [x + self.num_image_tokens for x in tmp]
+ return tokens if add_dummy_prefix else tokens[2:]
+
+ def postprocess(self, text):
+ text = text.replace("", "\n")
+ text = text.replace(SPTokenizer.get_tab_token(), "\t")
+ for i in range(2, self.max_blank_length + 1):
+ text = text.replace(self.get_blank_token(i), " " * i)
+ return text
+
+ def decode(self, text_ids: List[int]) -> str:
+ ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+ ids = [_id for _id in ids if _id >= 0]
+ text = self._get_text_tokenizer().decode(ids)
+ text = self.postprocess(text)
+ return text
+
+ def decode_tokens(self, tokens: List[str]) -> str:
+ text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+ text = self.postprocess(text)
+ return text
+
+ def tokenize(
+ self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
+ ) -> List[str]:
+ """
+ @param text: Text to encode.
+ @param linebreak: Whether to encode newline (\n) in text.
+ @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
+ @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
+ @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
+ """
+ text = self._preprocess(text, linebreak, whitespaces)
+ if not add_dummy_prefix:
+ text = "" + text
+ tokens = self._get_text_tokenizer().tokenize(text)
+ return tokens if add_dummy_prefix else tokens[2:]
+
+ def __getitem__(self, x: Union[int, str]):
+ if isinstance(x, int):
+ if x < self.num_image_tokens:
+ return "".format(x)
+ else:
+ return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
+ elif isinstance(x, str):
+ if x.startswith("") and x[7:-1].isdigit():
+ return int(x[7:-1])
+ else:
+ return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
+ else:
+ raise ValueError("The key should be str or int.")
+
+
+class ChatGLMTokenizer(PreTrainedTokenizer):
+ """
+ Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ """
+
+ vocab_files_names = {"vocab_file": "ice_text.model"}
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=False,
+ remove_space=False,
+        bos_token='<sop>',
+        eos_token='</s>',
+        end_token='</s>',
+        mask_token='[MASK]',
+        gmask_token='[gMASK]',
+        padding_side="left",
+        pad_token="<pad>",
+        unk_token="<unk>",
+ num_image_tokens=20000,
+ **kwargs
+ ) -> None:
+ super().__init__(
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ padding_side=padding_side,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ end_token=end_token,
+ mask_token=mask_token,
+ gmask_token=gmask_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ num_image_tokens=num_image_tokens,
+ **kwargs
+ )
+
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.vocab_file = vocab_file
+
+ self.bos_token = bos_token
+ self.eos_token = eos_token
+ self.end_token = end_token
+ self.mask_token = mask_token
+ self.gmask_token = gmask_token
+
+ self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens)
+
+ """ Initialisation """
+
+ @property
+ def gmask_token_id(self) -> Optional[int]:
+ if self.gmask_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.gmask_token)
+
+ @property
+ def end_token_id(self) -> Optional[int]:
+ """
+ `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
+ set.
+ """
+ if self.end_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.end_token)
+
+ @property
+ def vocab_size(self):
+ """ Returns vocab size """
+ return self.sp_tokenizer.num_tokens
+
+ def get_vocab(self):
+ """ Returns vocab as a dict """
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def preprocess_text(self, inputs):
+ if self.remove_space:
+ outputs = " ".join(inputs.strip().split())
+ else:
+ outputs = inputs
+
+ if self.do_lower_case:
+ outputs = outputs.lower()
+
+ return outputs
+
+ def _tokenize(self, text, **kwargs):
+ """ Returns a tokenized string. """
+ text = self.preprocess_text(text)
+
+ seq = self.sp_tokenizer.tokenize(text)
+
+ return seq
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ return self.sp_tokenizer.decode_tokens(tokens)
+
+ def _decode(
+ self,
+ token_ids: Union[int, List[int]],
+ **kwargs
+ ) -> str:
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
+ if len(token_ids) == 0:
+ return ""
+ if self.pad_token_id in token_ids: # remove pad
+ token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
+ return super()._decode(token_ids, **kwargs)
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.sp_tokenizer[token]
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.sp_tokenizer[index]
+
+ def save_vocabulary(self, save_directory, filename_prefix=None):
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+                An optional prefix to add to the names of the saved files.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, self.vocab_files_names["vocab_file"]
+ )
+ else:
+ vocab_file = save_directory
+
+ with open(self.vocab_file, 'rb') as fin:
+ proto_str = fin.read()
+
+ with open(vocab_file, "wb") as writer:
+ writer.write(proto_str)
+
+ return (vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+        Build model inputs from a sequence or a pair of sequences by concatenating and
+        adding special tokens. A ChatGLM sequence has the following format:
+
+        - single sequence: `X [gMASK] <sop>`
+        - pair of sequences: `A [gMASK] <sop> B </s>`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ gmask_id = self.sp_tokenizer[self.gmask_token]
+ eos_id = self.sp_tokenizer[self.eos_token]
+ token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
+ if token_ids_1 is not None:
+            token_ids_0 = token_ids_0 + token_ids_1 + [eos_id]
+ return token_ids_0
+
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ bos_token_id = self.sp_tokenizer[self.bos_token]
+ mask_token_id = self.sp_tokenizer[self.mask_token]
+ gmask_token_id = self.sp_tokenizer[self.gmask_token]
+ assert self.padding_side == "left"
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ seq_length = len(required_input)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if max_length is not None:
+ if "attention_mask" not in encoded_inputs:
+ if bos_token_id in required_input:
+ context_length = required_input.index(bos_token_id)
+ else:
+ context_length = seq_length
+ attention_mask = np.ones((1, seq_length, seq_length))
+ attention_mask = np.tril(attention_mask)
+ attention_mask[:, :, :context_length] = 1
+ attention_mask = np.bool_(attention_mask < 0.5)
+ encoded_inputs["attention_mask"] = attention_mask
+
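+            # 2D position ids (sketch): for "t1 t2 [gMASK] <sop> g1 g2"
+            # (context_length=3, gMASK at index 2) the block below yields
+            #   position_ids       = [0, 1, 2, 2, 2, 2]  # frozen at the mask position
+            #   block_position_ids = [0, 0, 0, 1, 2, 3]  # counter inside the generated block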
+ if "position_ids" not in encoded_inputs:
+ if bos_token_id in required_input:
+ context_length = required_input.index(bos_token_id)
+ else:
+ context_length = seq_length
+ position_ids = np.arange(seq_length, dtype=np.int64)
+ mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+ if mask_token in required_input:
+ mask_position = required_input.index(mask_token)
+ position_ids[context_length:] = mask_position
+ block_position_ids = np.concatenate(
+ [np.zeros(context_length, dtype=np.int64),
+ np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+ encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+
+ if "attention_mask" in encoded_inputs:
+ encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
+ pad_width=[(0, 0), (difference, 0), (difference, 0)],
+ mode='constant', constant_values=True)
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ if "position_ids" in encoded_inputs:
+ encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
+ pad_width=[(0, 0), (difference, 0)])
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+ return encoded_inputs
\ No newline at end of file
diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
new file mode 100644
index 000000000000..d0e3f6cc63d7
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
@@ -0,0 +1,107 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/configuration_chatglm.py
+"""
+
+""" ChatGLM model configuration """
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class ChatGLMConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`~ChatGLMModel`].
+ It is used to instantiate an ChatGLM model according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used
+ to control the model outputs. Read the documentation from [`PretrainedConfig`]
+ for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 150528):
+ Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~ChatGLMModel`] or
+ [`~TFChatGLMModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ inner_hidden_size (`int`, *optional*, defaults to 16384):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with.
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+ layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values attentions (not used by all models).
+ Example:
+
+ ```python
+ >>> from configuration_chatglm import ChatGLMConfig
+ >>> from modeling_chatglm import ChatGLMModel
+
+ >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration
+ >>> configuration = ChatGLMConfig()
+
+ >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
+ >>> model = ChatGLMModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+"""
+ model_type = "chatglm"
+
+ def __init__(
+ self,
+ vocab_size=130528,
+ hidden_size=4096,
+ num_layers=28,
+ num_attention_heads=32,
+ layernorm_epsilon=1e-5,
+ use_cache=True,
+ bos_token_id=130004,
+ eos_token_id=130005,
+ mask_token_id=130000,
+ gmask_token_id=130001,
+ pad_token_id=3,
+ max_sequence_length=2048,
+ inner_hidden_size=16384,
+ position_encoding_2d=True,
+ quantization_bit=0,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs
+ ):
+ self.num_layers = num_layers
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.max_sequence_length = max_sequence_length
+ self.layernorm_epsilon = layernorm_epsilon
+ self.inner_hidden_size = inner_hidden_size
+ self.use_cache = use_cache
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.mask_token_id = mask_token_id
+ self.gmask_token_id = gmask_token_id
+ self.position_encoding_2d = position_encoding_2d
+ self.quantization_bit = quantization_bit
+ self.pre_seq_len = pre_seq_len
+ self.prefix_projection = prefix_projection
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs
+ )
\ No newline at end of file
diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
new file mode 100644
index 000000000000..77e7d0d8ea09
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
@@ -0,0 +1,1439 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py
+"""
+
+""" PyTorch ChatGLM model. """
+
+import math
+import copy
+import os
+import warnings
+import re
+import sys
+
+import torch
+import torch.utils.checkpoint
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn.utils import skip_init
+from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ BaseModelOutputWithPastAndCrossAttentions,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
+
+from .configuration_chatglm import ChatGLMConfig
+
+# flags required to enable jit fusion kernels
+
+if sys.platform != 'darwin':
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_override_can_fuse_on_cpu(True)
+ torch._C._jit_override_can_fuse_on_gpu(True)
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
+_CONFIG_FOR_DOC = "ChatGLM6BConfig"
+
+CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "THUDM/chatglm-6b",
+ # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
+]
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
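+            # If the logits degenerate to NaN/Inf, fall back to a fixed spike on
+            # token id 5 so that sampling can continue instead of crashing.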
+ scores.zero_()
+ scores[..., 5] = 5e4
+ return scores
+
+
+def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ assert (
+ pointer.shape == array.shape
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+class PrefixEncoder(torch.nn.Module):
+ """
+ The torch.nn model to encode the prefix
+ Input shape: (batch-size, prefix-length)
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.prefix_projection = config.prefix_projection
+ if self.prefix_projection:
+ # Use a two-layer MLP to encode the prefix
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+ self.trans = torch.nn.Sequential(
+ torch.nn.Linear(config.hidden_size, config.hidden_size),
+ torch.nn.Tanh(),
+ torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+ )
+ else:
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
+
+ def forward(self, prefix: torch.Tensor):
+ if self.prefix_projection:
+ prefix_tokens = self.embedding(prefix)
+ past_key_values = self.trans(prefix_tokens)
+ else:
+ past_key_values = self.embedding(prefix)
+ return past_key_values
+
+
+@torch.jit.script
+def gelu_impl(x):
+ """OpenAI's gelu implementation."""
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
+ (1.0 + 0.044715 * x * x)))
+
+
+def gelu(x):
+ return gelu_impl(x)
+
+
+class RotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ inv_freq = inv_freq.half()
+ self.learnable = learnable
+ if learnable:
+ self.inv_freq = torch.nn.Parameter(inv_freq)
+ self.max_seq_len_cached = None
+ else:
+ self.register_buffer('inv_freq', inv_freq)
+ self.max_seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+ self.precision = precision
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+ error_msgs):
+ pass
+
+ def forward(self, x, seq_dim=1, seq_len=None):
+ if seq_len is None:
+ seq_len = x.shape[seq_dim]
+ if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+ self.max_seq_len_cached = None if self.learnable else seq_len
+ t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ if self.precision == torch.bfloat16:
+ emb = emb.float()
+
+ # [sx, 1 (b * np), hn]
+ cos_cached = emb.cos()[:, None, :]
+ sin_cached = emb.sin()[:, None, :]
+ if self.precision == torch.bfloat16:
+ cos_cached = cos_cached.bfloat16()
+ sin_cached = sin_cached.bfloat16()
+ if self.learnable:
+ return cos_cached, sin_cached
+ self.cos_cached, self.sin_cached = cos_cached, sin_cached
+ return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+
+ def _apply(self, fn):
+ if self.cos_cached is not None:
+ self.cos_cached = fn(self.cos_cached)
+ if self.sin_cached is not None:
+ self.sin_cached = fn(self.sin_cached)
+ return super()._apply(fn)
+
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+
+@torch.jit.script
+def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
+ # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
+ F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+ return q, k
+
+
+def attention_fn(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ hidden_size_per_partition,
+ layer_id,
+ layer_past=None,
+ scaling_attention_score=True,
+ use_cache=False,
+):
+ if layer_past is not None:
+ past_key, past_value = layer_past[0], layer_past[1]
+ key_layer = torch.cat((past_key, key_layer), dim=0)
+ value_layer = torch.cat((past_value, value_layer), dim=0)
+
+ # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
+ seq_len, b, nh, hidden_size = key_layer.shape
+
+ if use_cache:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ query_key_layer_scaling_coeff = float(layer_id + 1)
+ if scaling_attention_score:
+ query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
+
+ # ===================================
+ # Raw attention scores. [b, np, s, s]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+
+ matmul_result = torch.zeros(
+ 1, 1, 1,
+ dtype=query_layer.dtype,
+ device=query_layer.device,
+ )
+
+ matmul_result = torch.baddbmm(
+ matmul_result,
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
+ beta=0.0,
+ alpha=1.0,
+ )
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ if self.scale_mask_softmax:
+ self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
+ else:
+ if not (attention_mask == 0).all():
+ # if auto-regressive, skip
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+ dtype = attention_scores.dtype
+ attention_scores = attention_scores.float()
+ attention_scores = attention_scores * query_key_layer_scaling_coeff
+
+ attention_probs = F.softmax(attention_scores, dim=-1)
+
+ attention_probs = attention_probs.type(dtype)
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, present, attention_probs)
+
+ return outputs
+
+
+def default_init(cls, *args, **kwargs):
+ return cls(*args, **kwargs)
+
+
+class SelfAttention(torch.nn.Module):
+ def __init__(self, hidden_size, num_attention_heads,
+ layer_id, hidden_size_per_attention_head=None, bias=True,
+ params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ super(SelfAttention, self).__init__()
+
+ self.layer_id = layer_id
+ self.hidden_size = hidden_size
+ self.hidden_size_per_partition = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_attention_heads_per_partition = num_attention_heads
+ self.position_encoding_2d = position_encoding_2d
+ self.rotary_emb = RotaryEmbedding(
+ self.hidden_size // (self.num_attention_heads * 2)
+ if position_encoding_2d
+ else self.hidden_size // self.num_attention_heads,
+ base=10000,
+ precision=torch.half,
+ learnable=False,
+ )
+
+ self.scale_mask_softmax = None
+
+ if hidden_size_per_attention_head is None:
+ self.hidden_size_per_attention_head = hidden_size // num_attention_heads
+ else:
+ self.hidden_size_per_attention_head = hidden_size_per_attention_head
+
+ self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+
+ # Strided linear layer.
+ self.query_key_value = init_method(
+ torch.nn.Linear,
+ hidden_size,
+ 3 * self.inner_hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ self.dense = init_method(
+ torch.nn.Linear,
+ self.inner_hidden_size,
+ hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ @staticmethod
+ def attention_mask_func(attention_scores, attention_mask):
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+ return attention_scores
+
+ def split_tensor_along_last_dim(self, tensor, num_partitions,
+ contiguous_split_chunks=False):
+ """Split a tensor along its last dimension.
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ attention_mask: [(1, 1), seq_len, seq_len]
+ """
+
+ # [seq_len, batch, 3 * hidden_size]
+ mixed_raw_layer = self.query_key_value(hidden_states)
+
+ # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
+ new_tensor_shape = mixed_raw_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)
+
+ # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
+ (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+ if self.position_encoding_2d:
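+            # 2D RoPE: split each head in half; rotate the first half with the absolute
+            # position ids and the second half with the in-block position ids.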
+ q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
+ k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
+ cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
+ position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
+ position_ids[:, 1, :].transpose(0, 1).contiguous()
+ q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
+ q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
+ query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
+ key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
+ else:
+ position_ids = position_ids.transpose(0, 1)
+ cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
+ # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
+ query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)
+
+ # [seq_len, batch, hidden_size]
+ context_layer, present, attention_probs = attention_fn(
+ self=self,
+ query_layer=query_layer,
+ key_layer=key_layer,
+ value_layer=value_layer,
+ attention_mask=attention_mask,
+ hidden_size_per_partition=self.hidden_size_per_partition,
+ layer_id=layer_id,
+ layer_past=layer_past,
+ use_cache=use_cache
+ )
+
+ output = self.dense(context_layer)
+
+ outputs = (output, present)
+
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs # output, present, attention_probs
+
+
+class GEGLU(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.activation_fn = F.gelu
+
+ def forward(self, x):
+ # dim=-1 breaks in jit for pt<1.10
+ x1, x2 = x.chunk(2, dim=(x.ndim - 1))
+ return x1 * self.activation_fn(x2)
+
+
+class GLU(torch.nn.Module):
+ def __init__(self, hidden_size, inner_hidden_size=None,
+ layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
+ super(GLU, self).__init__()
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ self.layer_id = layer_id
+ self.activation_func = activation_func
+
+ # Project to 4h.
+ self.hidden_size = hidden_size
+ if inner_hidden_size is None:
+ inner_hidden_size = 4 * hidden_size
+ self.inner_hidden_size = inner_hidden_size
+ self.dense_h_to_4h = init_method(
+ torch.nn.Linear,
+ self.hidden_size,
+ self.inner_hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+ # Project back to h.
+ self.dense_4h_to_h = init_method(
+ torch.nn.Linear,
+ self.inner_hidden_size,
+ self.hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ def forward(self, hidden_states):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ """
+
+ # [seq_len, batch, inner_hidden_size]
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+
+ output = self.dense_4h_to_h(intermediate_parallel)
+
+ return output
+
+
+class GLMBlock(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_attention_heads,
+ layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=None,
+ hidden_size_per_attention_head=None,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=torch.float,
+ num_layers=28,
+ position_encoding_2d=True,
+ empty_init=True
+ ):
+ super(GLMBlock, self).__init__()
+ # Set output layer initialization if not provided.
+
+ self.layer_id = layer_id
+
+ # Layernorm on the input data.
+ self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+ self.position_encoding_2d = position_encoding_2d
+
+ # Self attention.
+ self.attention = SelfAttention(
+ hidden_size,
+ num_attention_heads,
+ layer_id,
+ hidden_size_per_attention_head=hidden_size_per_attention_head,
+ bias=use_bias,
+ params_dtype=params_dtype,
+ position_encoding_2d=self.position_encoding_2d,
+ empty_init=empty_init
+ )
+
+ # Layernorm on the input data.
+ self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+ self.num_layers = num_layers
+
+ # GLU
+ self.mlp = GLU(
+ hidden_size,
+ inner_hidden_size=inner_hidden_size,
+ bias=use_bias,
+ layer_id=layer_id,
+ params_dtype=params_dtype,
+ empty_init=empty_init
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ attention_mask: [(1, 1), seq_len, seq_len]
+ """
+
+        # Layer norm at the beginning of the transformer layer.
+ # [seq_len, batch, hidden_size]
+ attention_input = self.input_layernorm(hidden_states)
+
+ # Self attention.
+ attention_outputs = self.attention(
+ attention_input,
+ position_ids,
+ attention_mask=attention_mask,
+ layer_id=layer_id,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions
+ )
+
+ attention_output = attention_outputs[0]
+
+ outputs = attention_outputs[1:]
+
+ # Residual connection.
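+        # GLM scales the residual branch by sqrt(2 * num_layers) (DeepNorm-style),
+        # which keeps activations stable as depth grows.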
+ alpha = (2 * self.num_layers) ** 0.5
+ hidden_states = attention_input * alpha + attention_output
+
+ mlp_input = self.post_attention_layernorm(hidden_states)
+
+ # MLP.
+ mlp_output = self.mlp(mlp_input)
+
+ # Second residual connection.
+ output = mlp_input * alpha + mlp_output
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+
+class ChatGLMPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+
+ is_parallelizable = False
+ supports_gradient_checkpointing = True
+ config_class = ChatGLMConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["GLMBlock"]
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ return
+
+ def get_masks(self, input_ids, device):
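+        # Prefix-LM mask (sketch): tokens before <sop> (the prompt) attend bidirectionally,
+        # later tokens are causally masked. True marks positions that must NOT be attended.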
+ batch_size, seq_length = input_ids.shape
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+ attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
+ attention_mask.tril_()
+ for i, context_length in enumerate(context_lengths):
+ attention_mask[i, :, :context_length] = 1
+ attention_mask.unsqueeze_(1)
+ attention_mask = (attention_mask < 0.5).bool()
+
+ return attention_mask
+
+ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
+ batch_size, seq_length = input_ids.shape
+ if use_gmasks is None:
+ use_gmasks = [False] * batch_size
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+ if self.position_encoding_2d:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ for i, context_length in enumerate(context_lengths):
+ position_ids[i, context_length:] = mask_positions[i]
+ block_position_ids = [torch.cat((
+ torch.zeros(context_length, dtype=torch.long, device=device),
+ torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+ )) for context_length in context_lengths]
+ block_position_ids = torch.stack(block_position_ids, dim=0)
+ position_ids = torch.stack((position_ids, block_position_ids), dim=1)
+ else:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ for i, context_length in enumerate(context_lengths):
+ if not use_gmasks[i]:
+ position_ids[i, context_length:] = mask_positions[i]
+
+ return position_ids
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, ChatGLMModel):
+ module.gradient_checkpointing = value
+
+
+CHATGLM_6B_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
+ usage and behavior.
+
+ Parameters:
+ config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
+ Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CHATGLM_6B_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`ChatGLM6BTokenizer`].
+ See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Selected in the range `[0, config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert *input_ids* indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
+ CHATGLM_6B_START_DOCSTRING,
+)
+class ChatGLMModel(ChatGLMPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well
+ as a decoder, in which case a layer of cross-attention is added between
+ the self-attention layers, following the architecture described in [Attention is
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the
+    `is_decoder` argument of the configuration set to `True`.
+    To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder`
+ argument and `add_cross_attention` set to `True`; an
+ `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ # recording parameters
+ self.max_sequence_length = config.max_sequence_length
+ self.hidden_size = config.hidden_size
+ self.params_dtype = torch.half
+ self.num_attention_heads = config.num_attention_heads
+ self.vocab_size = config.vocab_size
+ self.num_layers = config.num_layers
+ self.layernorm_epsilon = config.layernorm_epsilon
+ self.inner_hidden_size = config.inner_hidden_size
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
+ self.position_encoding_2d = config.position_encoding_2d
+ self.pre_seq_len = config.pre_seq_len
+ self.prefix_projection = config.prefix_projection
+
+ self.word_embeddings = init_method(
+ torch.nn.Embedding,
+ num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
+ dtype=self.params_dtype
+ )
+ self.gradient_checkpointing = False
+
+ def get_layer(layer_id):
+ return GLMBlock(
+ self.hidden_size,
+ self.num_attention_heads,
+ self.layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=self.inner_hidden_size,
+ hidden_size_per_attention_head=self.hidden_size_per_attention_head,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=self.params_dtype,
+ position_encoding_2d=self.position_encoding_2d,
+ empty_init=empty_init
+ )
+
+ self.layers = torch.nn.ModuleList(
+ [get_layer(layer_id) for layer_id in range(self.num_layers)]
+ )
+
+ # Final layer norm before output.
+ self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
+
+ if self.pre_seq_len is not None:
+ for param in self.parameters():
+ param.requires_grad = False
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+ self.dropout = torch.nn.Dropout(0.1)
+
+ # total_params = sum(p.numel() for p in self.parameters())
+ # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.word_embeddings = new_embeddings
+
+ def get_prompt(self, batch_size, device, dtype=torch.half):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.num_layers * 2,
+ self.num_attention_heads,
+ self.hidden_size // self.num_attention_heads
+ )
+ # seq_len, b, nh, hidden_size
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+ # past_key_values = [(v[0], v[1]) for v in past_key_values]
+ return past_key_values
+
+ @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if past_key_values is None:
+ if self.pre_seq_len is not None:
+                past_key_values = self.get_prompt(batch_size=batch_size, device=inputs_embeds.device,
+                                                  dtype=inputs_embeds.dtype)
+ else:
+ past_key_values = tuple([None] * len(self.layers))
+
+ if attention_mask is None:
+ attention_mask = self.get_masks(
+ input_ids,
+ device=input_ids.device
+ )
+
+ if position_ids is None:
+ MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
+ seqs = input_ids.tolist()
+
+ mask_positions, use_gmasks = [], []
+ for seq in seqs:
+ mask_token = gMASK if gMASK in seq else MASK
+ use_gmask = mask_token == gMASK
+ mask_positions.append(seq.index(mask_token))
+ use_gmasks.append(use_gmask)
+
+ position_ids = self.get_position_ids(
+ input_ids,
+ mask_positions=mask_positions,
+ device=input_ids.device,
+ use_gmasks=use_gmasks
+ )
+
+ if self.pre_seq_len is not None and attention_mask is not None:
+ prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+ attention_mask.device)
+ prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
+
+ # [seq_len, batch, hidden_size]
+ hidden_states = inputs_embeds.transpose(0, 1)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if attention_mask is None:
+ attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ for i, layer in enumerate(self.layers):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ layer_past = past_key_values[i]
+
+ if self.gradient_checkpointing and self.training:
+ layer_ret = torch.utils.checkpoint.checkpoint(
+ layer,
+ hidden_states,
+ position_ids,
+ attention_mask,
+ torch.tensor(i),
+ layer_past,
+ use_cache,
+ output_attentions
+ )
+ else:
+ layer_ret = layer(
+ hidden_states,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ layer_id=torch.tensor(i),
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions
+ )
+
+ hidden_states = layer_ret[0]
+
+ if use_cache:
+ presents = presents + (layer_ret[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
+
+ # Final layer norm.
+ hidden_states = self.final_layernorm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+
+ # self.hidden_size = config.hidden_size
+ # self.params_dtype = torch.half
+ # self.vocab_size = config.vocab_size
+ self.max_sequence_length = config.max_sequence_length
+
+ self.position_encoding_2d = config.position_encoding_2d
+
+ self.transformer = ChatGLMModel(config, empty_init=empty_init)
+
+ self.lm_head = init_method(
+ nn.Linear,
+ config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=torch.half
+ )
+
+ self.config = config
+
+ self.quantized = False
+
+ if self.config.quantization_bit:
+ self.quantize(self.config.quantization_bit, empty_init=True)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ # update past_key_values
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+
+ # update attention mask
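+        # The 4-D bool mask (True = masked) has shape (batch, 1, q_len, k_len):
+        # append a masked key column so earlier queries ignore the new token,
+        # then append a query row that lets the new token attend to everything,
+        # including itself.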
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+ new_attention_mask = attention_mask[:, :, -1:].clone()
+ new_attention_mask[..., -1] = False
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, new_attention_mask], dim=2
+ )
+
+ # update position ids
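+        # With 2D position encoding, position_ids has shape (batch, 2, seq):
+        # channel 0 keeps the context position, channel 1 is the generation
+        # step, which advances by one for each newly generated token.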
+ if "position_ids" in model_kwargs:
+ position_ids = model_kwargs["position_ids"]
+ new_position_id = position_ids[..., -1:].clone()
+ new_position_id[:, 1, :] += 1
+ model_kwargs["position_ids"] = torch.cat(
+ [position_ids, new_position_id], dim=-1
+ )
+
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past: Optional[torch.Tensor] = None,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs
+ ) -> dict:
+ batch_size, seq_length = input_ids.shape
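+        # Every ChatGLM prompt carries a [MASK] or [gMASK] sentinel; its index
+        # anchors the position ids computed below.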
+ MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
+ seqs = input_ids.tolist()
+ mask_positions, use_gmasks = [], []
+ for seq in seqs:
+ mask_token = gMASK if gMASK in seq else MASK
+ use_gmask = mask_token == gMASK
+ mask_positions.append(seq.index(mask_token))
+ use_gmasks.append(use_gmask)
+
+ # only last token for input_ids if past is not None
+ if past is not None or past_key_values is not None:
+ last_token = input_ids[:, -1].unsqueeze(-1)
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
+ attention_mask = attention_mask[:, :, -1:]
+ else:
+ attention_mask = None
+ if position_ids is not None:
+ position_ids = position_ids[..., -1:]
+ else:
+ context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
+ if self.position_encoding_2d:
+ position_ids = torch.tensor(
+ [[mask_position, seq_length - context_length] for mask_position, context_length in
+ zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+ else:
+                position_ids = torch.tensor(mask_positions, dtype=torch.long,
+                                            device=input_ids.device).unsqueeze(-1)
+
+ if past is None:
+ past = past_key_values
+ return {
+ "input_ids": last_token,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask
+ }
+ else:
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
+ logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
+ attention_mask = None
+ if attention_mask is None:
+ attention_mask = self.get_masks(
+ input_ids,
+ device=input_ids.device
+ )
+ if position_ids is None:
+ position_ids = self.get_position_ids(
+ input_ids,
+ device=input_ids.device,
+ mask_positions=mask_positions,
+ use_gmasks=use_gmasks
+ )
+
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask
+ }
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()
+
+ loss = None
+ if labels is not None:
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+
+ Output shares the same memory storage as `past`.
+ """
+ return tuple(
+ (
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
+ )
+ for layer_past in past
+ )
+
+ def process_response(self, response):
+ response = response.strip()
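+        # "[[训练时间]]" ("training time") is a placeholder that the model emits;
+        # it is replaced with "2023年" ("the year 2023"). The pairs below map
+        # ASCII punctuation to its full-width form whenever it is adjacent to a
+        # CJK character (the \u4e00-\u9fff range used in the regexes).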
+ response = response.replace("[[训练时间]]", "2023年")
+ punkts = [
+ [",", ","],
+ ["!", "!"],
+ [":", ":"],
+ [";", ";"],
+ ["\?", "?"],
+ ]
+ for item in punkts:
+ response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+ response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+ return response
+
+ @torch.no_grad()
+    def chat(self, tokenizer, query: str, history: Optional[List[Tuple[str, str]]] = None, max_length: int = 2048, num_beams=1,
+ do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ if not history:
+ prompt = query
+ else:
+ prompt = ""
+ for i, (old_query, response) in enumerate(history):
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ outputs = self.generate(**inputs, **gen_kwargs)
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ response = tokenizer.decode(outputs)
+ response = self.process_response(response)
+ history = history + [(query, response)]
+ return response, history
+
+ @torch.no_grad()
+    def stream_chat(self, tokenizer, query: str, history: Optional[List[Tuple[str, str]]] = None, max_length: int = 2048,
+ do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ if not history:
+ prompt = query
+ else:
+ prompt = ""
+ for i, (old_query, response) in enumerate(history):
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ for outputs in self.stream_generate(**inputs, **gen_kwargs):
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ response = tokenizer.decode(outputs)
+ response = self.process_response(response)
+ new_history = history + [(query, response)]
+ yield response, new_history
+
+ @torch.no_grad()
+ def stream_generate(
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ **kwargs,
+ ):
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ if not has_default_max_length:
+                logger.warning(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ logits_warper = self._get_logits_warper(generation_config)
+
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ scores = None
+ while True:
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ if generation_config.do_sample:
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(probs, dim=-1)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+ unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ break
+ yield input_ids
+
+ def quantize(self, bits: int, empty_init=False, **kwargs):
+ if bits == 0:
+ return
+
+ from .quantization import quantize
+
+ if self.quantized:
+ logger.info("Already quantized.")
+ return self
+
+ self.quantized = True
+
+ self.config.quantization_bit = bits
+
+ self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
+ return self
diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 546f675d7d37..f1597da540a7 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -48,7 +48,7 @@ def __init__(
def reset_parameters(self):
if hasattr(self, 'lora_A'):
- # initialize A the same way as the default for nn.Linear and B to zero
+ # Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py
index 7b9df2ee139b..e04bf5ccb881 100644
--- a/applications/Chat/coati/ray/detached_replay_buffer.py
+++ b/applications/Chat/coati/ray/detached_replay_buffer.py
@@ -16,7 +16,7 @@
class DetachedReplayBuffer:
'''
Detached replay buffer. Share Experience across workers on the same node.
- Therefore a trainer node is expected to have only one instance.
+ Therefore, a trainer node is expected to have only one instance.
     It is ExperienceMakerHolder's duty to call the append(exp) method remotely.
Args:
diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py
index 761186b95ee5..391ffe7a91a9 100644
--- a/applications/Chat/coati/ray/utils.py
+++ b/applications/Chat/coati/ray/utils.py
@@ -116,7 +116,7 @@ def get_model_numel(model: nn.Module) -> int:
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
target_receivers = []
if num_senders <= num_receivers or allow_idle_sender:
- # a sender will send data to one or more than one receivers
+ # a sender will send data to one or more receivers
# a receiver only has one sender
for i in range(num_receivers):
if i % num_senders == sender_idx:
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index 0812ba165286..e4d0a970740d 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -52,9 +52,13 @@ def _train(self, epoch: int):
for batch_id, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
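+            # ChatGLM batches carry no attention_mask (the model derives it
+            # from input_ids), so pass the mask only when it is present.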
- outputs = self.model(batch["input_ids"],
- attention_mask=batch["attention_mask"],
- labels=batch["labels"])
+ if "attention_mask" in batch:
+ outputs = self.model(batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"])
+ else:
+ outputs = self.model(batch["input_ids"],
+ labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md
index 68b03be16a30..0a97ae72f9d0 100644
--- a/applications/Chat/evaluate/README.md
+++ b/applications/Chat/evaluate/README.md
@@ -348,7 +348,7 @@ For example, if you want to add a new metric `persuasiveness` into category `bra
How can I add a new UniEval evaluation metric?
-For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in function `add_question` in `unieval/utils.py`. Please do note that how effectively the model would evaluate this metric is unknown and you may need some experiments to test whether the model is capable of evaluating this metric.
+For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in function `add_question` in `unieval/utils.py`. Please note that it is unknown how effectively the model would evaluate this metric, and you may need to run some experiments to test whether the model is capable of evaluating it.
```python
if task == 'data2text':
diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/Chat/evaluate/gpt_evaluate.py
index f8cfb8d0f7e5..6fcbe63d0253 100644
--- a/applications/Chat/evaluate/gpt_evaluate.py
+++ b/applications/Chat/evaluate/gpt_evaluate.py
@@ -576,7 +576,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float:
for key, value in logprobs.items():
# Sometimes the key will be one byte of a unicode character which takes the form of "bytes:\\xe7".
- # It is meaningless and thus we don't calculate probability.
+ # It is meaningless, and thus we don't calculate probability.
if "bytes" in key:
continue
# results[0] is the score which corresponds to the key(predicted token).
@@ -621,7 +621,7 @@ def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[st
Args:
model_name: name of the model for saving evaluation results.
- gpt_evaluation_results: evaluations results for all of the model answers.
+ gpt_evaluation_results: evaluations results for all the model answers.
save_path: path to save GPT evaluation statistics.
"""
@@ -641,7 +641,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
Args:
model_name: name of the model for saving statistics.
- evaluations: evaluations for all of the model answers.
+ evaluations: evaluations for all the model answers.
save_path: path to save GPT evaluation statistics.
"""
@@ -663,7 +663,7 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
for evaluation in data:
for metric in metrics:
if evaluation["evaluation"][metric] == {}:
- # This means after 3 retries, the server still returns an error and we set the score to 0.
+ # This means after 3 retries, the server still returns an error, and we set the score to 0.
scores[metric].append(0)
elif evaluation["evaluation"][metric]["logprobs"] is not None:
scores[metric].append(
diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md
index 8b2edc48cd99..ada3a16296af 100644
--- a/applications/Chat/examples/community/peft/README.md
+++ b/applications/Chat/examples/community/peft/README.md
@@ -20,7 +20,7 @@ pip install .
For SFT training, just call train_peft_sft.py
-Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py.
+Its arguments are almost identical to those of train_sft.py, except that it adds a new eval_dataset argument for when you have an eval_dataset file. The data file is just a plain data file; please check the format in easy_dataset.py.
For stage-3 rlhf training, call train_peft_prompts.py.
Its arguments are almost identical to those of train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data files. The models are included in easy_models.py. Currently only BLOOM models are tested, but technically GPT2/OPT/LLaMA should be supported.
diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt
index 40e6edc7ea73..5d0f9f927d17 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/Chat/examples/requirements.txt
@@ -1,2 +1,3 @@
pandas>=1.4.1
sentencepiece
+colossalai==0.3.1
\ No newline at end of file
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index 7585cf3ed0da..f068ea2bf5de 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -9,13 +9,15 @@
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
+from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler
@@ -58,6 +60,8 @@ def train(args):
model = LlamaActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
+ elif args.model == 'chatglm':
+ model = ChatGLMActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -81,6 +85,9 @@ def train(args):
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
             tokenizer.eos_token = '</s>'
tokenizer.pad_token = tokenizer.unk_token
+ elif args.model == 'chatglm':
+ tokenizer = ChatGLMTokenizer.from_pretrained(
+ "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -99,7 +106,6 @@ def train(args):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
-
logger = get_dist_logger()
# configure dataset
@@ -185,7 +191,7 @@ def train(args):
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
+ parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None)
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
index e079f8a6038d..eb1a77875acb 100644
--- a/applications/Chat/requirements-test.txt
+++ b/applications/Chat/requirements-test.txt
@@ -1 +1,2 @@
pytest
+colossalai==0.3.1
\ No newline at end of file
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
index af7ff67861eb..e5f5ca0932a8 100644
--- a/applications/Chat/requirements.txt
+++ b/applications/Chat/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
-colossalai>=0.2.4
+colossalai==0.3.1
torch<2.0.0, >=1.12.1
langchain
tokenizers
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
index 64ea1178cd0d..f9dee1bae935 100644
--- a/applications/Chat/tests/test_dataset.py
+++ b/applications/Chat/tests/test_dataset.py
@@ -11,32 +11,46 @@
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
SFT_DATASET = [
{
- "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
- "input": "",
- "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id": 0
+ "instruction":
+ "Provide a list of the top 10 most popular mobile games in Asia",
+ "input":
+ "",
+ "output":
+ "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id":
+ 0
},
{
- "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
- "input": "",
- "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
- "id": 1
+ "instruction":
+ "Please provide an action plan for reducing carbon footprint on a corporate level",
+ "input":
+ "",
+ "output":
+ "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
+ "id":
+ 1
},
{
- "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
- "input": "",
- "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
- "id": 2
+ "instruction":
+ "Write a persuasive email to your boss explaining why you should have a pay raise",
+ "input":
+ "",
+ "output":
+ "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
+ "id":
+ 2
},
]
PROMPT_DATASET = [
{
- "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
- "id": 0
+ "instruction":
+ "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
+ "id":
+ 0
},
{
"instruction": "Write a descriptive paragraph about a memorable vacation you went on",
@@ -66,14 +80,14 @@ def make_tokenizer(model: str):
elif model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.unk_token
+ elif model == "chatglm":
+ tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
else:
raise ValueError(f"Unsupported model '{model}'")
return tokenizer
-def check_content(input_ids_stripped: torch.Tensor,
- tokenizer: PreTrainedTokenizer,
- model: str):
+def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
if model == "opt":
# NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt.
assert input_ids_stripped[0] == tokenizer.eos_token_id
@@ -81,22 +95,25 @@ def check_content(input_ids_stripped: torch.Tensor,
elif model == "llama":
assert input_ids_stripped[0] == tokenizer.bos_token_id
input_ids_stripped = input_ids_stripped[1:]
-
+ elif model == "chatglm":
+ assert input_ids_stripped[0] == tokenizer.bos_token_id
+ assert input_ids_stripped[-1] == tokenizer.eos_token_id
+ input_ids_stripped = input_ids_stripped[1:-1]
assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
assert input_ids_stripped != tokenizer.sep_token_id
assert input_ids_stripped != tokenizer.cls_token_id
- assert input_ids_stripped != tokenizer.mask_token_id
+ if model == "chatglm":
+ assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
+ else:
+ assert input_ids_stripped != tokenizer.mask_token_id
-@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("max_length", [32, 1024])
@pytest.mark.parametrize("max_datasets_size", [2])
-def test_prompt_dataset(model: str,
- max_datasets_size: int,
- max_length: int):
+def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
with tempfile.TemporaryDirectory() as tmp_dir:
dataset_name = "prompt_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
@@ -119,19 +136,12 @@ def test_prompt_dataset(model: str,
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
-@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
-@pytest.mark.parametrize(["dataset_path", "subset"], [
- ("Anthropic/hh-rlhf", "harmless-base"),
- ("Dahoas/rm-static", None)
-])
+@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"),
+ ("Dahoas/rm-static", None)])
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
-def test_reward_dataset(model: str,
- dataset_path: str,
- subset: Optional[str],
- max_datasets_size: int,
- max_length: int):
+def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
data = load_dataset(dataset_path, data_dir=subset)
assert max_datasets_size <= len(data["train"]) \
and max_datasets_size <= len(data["test"])
@@ -188,15 +198,12 @@ def test_reward_dataset(model: str,
assert torch.all(r_mask)
-@pytest.mark.cpu
-@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024])
-def test_sft_dataset(model: str,
- dataset_path: Optional[str],
- max_dataset_size: int,
- max_length: int):
+def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
tokenizer = make_tokenizer(model)
if dataset_path == "yizhongw/self_instruct":
data = load_dataset(dataset_path, "super_natural_instructions")
@@ -213,6 +220,19 @@ def test_sft_dataset(model: str,
max_length=max_length)
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ for i in range(max_dataset_size):
+ assert isinstance(sft_dataset[i], dict)
+ assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
+ input_ids = sft_dataset[i]["input_ids"]
+ labels = sft_dataset[i]["labels"]
+ assert input_ids.shape == labels.shape == torch.Size([max_length])
+
+ ignore_mask = labels == IGNORE_INDEX
+ assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
+ check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
+ return
+
for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
@@ -232,10 +252,7 @@ def test_sft_dataset(model: str,
if __name__ == "__main__":
- test_sft_dataset(model="bloom",
- dataset_path="yizhongw/self_instruct",
- max_dataset_size=2,
- max_length=256)
+ test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
test_reward_dataset(model="gpt2",
dataset_path="Anthropic/hh-rlhf",
@@ -246,3 +263,4 @@ def test_sft_dataset(model: str,
test_prompt_dataset(model="opt",
max_datasets_size=2,
max_length=128)
+
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
index bd6b3e8a5ad1..b98b3615cd28 100644
--- a/applications/Chat/tests/test_models.py
+++ b/applications/Chat/tests/test_models.py
@@ -9,22 +9,26 @@
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.chatglm import ChatGLMActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
-
-@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
-@pytest.mark.parametrize("actor_maker", [
- lambda: BLOOMActor(),
- lambda: GPTActor(),
+@pytest.mark.parametrize(
+ "actor_maker",
+ [
+ lambda: BLOOMActor(),
+ lambda: GPTActor(),
# HACK: skip llama due to long execution time
# lambda: LlamaActor(),
- lambda: OPTActor()
+ lambda: OPTActor(),
+ # lambda: ChatGLMActor(),
])
@pytest.mark.parametrize("generate_kwargs", [{
"max_length": 64,
"use_cache": True,
@@ -32,23 +36,15 @@
"temperature": 1.0,
"top_k": 50,
}])
-def test_generation(actor_maker: Callable[[], Actor],
- batch_size: int,
- seq_len: int,
- generate_kwargs: Dict[str, Any]
- ):
+def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
-@pytest.mark.cpu
def test_utils():
- fn_input = {
- "tensor": torch.ones((10, )),
- "mask": torch.randint(0, 2, (10, ))
- }
+ fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
fn_output = masked_mean(dim=0, **fn_input)
assert fn_output.dim() == 0
assert torch.allclose(fn_output, torch.tensor(1.0))
@@ -56,14 +52,14 @@ def test_utils():
batch_size = 4
num_labels = 10
fn_input = {
- "r": torch.ones((batch_size, )),
+ "r": torch.ones((batch_size,)),
"kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)),
"action_mask": torch.randint(0, 2, (batch_size, num_labels))
}
fn_output = compute_reward(**fn_input)
- assert fn_output.shape == (batch_size, )
+ assert fn_output.shape == (batch_size,)
batch_size = 4
seq_len = 32
@@ -80,17 +76,11 @@ def test_utils():
assert fn_output.shape == (batch_size, num_actions)
-@pytest.mark.cpu
@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
-def test_lora(lora_rank: int,
- num_dim: int,
- num_layers: int):
- model = nn.ModuleList(
- [nn.Linear(num_dim, num_dim)
- for _ in range(num_layers)]
- )
+def test_lora(lora_rank: int, num_dim: int, num_layers: int):
+ model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
lora_model = convert_to_lora_module(model, lora_rank)
assert isinstance(lora_model, nn.ModuleList)
for i in range(num_layers):
@@ -103,8 +93,7 @@ def test_lora(lora_rank: int,
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
- assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
- lora_model[i].lora_B @ lora_model[i].lora_A)
+ assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
optimizer = torch.optim.Adam(lora_model.parameters())
x = torch.randn(8, num_dim)
for i in range(num_layers):
@@ -120,21 +109,22 @@ def test_lora(lora_rank: int,
lora_model[i].lora_B @ lora_model[i].lora_A)
-@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
-@pytest.mark.parametrize("models_maker", [
- lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
- lambda: (GPTActor(), GPTCritic(), GPTRM()),
+@pytest.mark.parametrize(
+ "models_maker",
+ [
+ lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
+ lambda: (GPTActor(), GPTCritic(), GPTRM()),
# HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda: (OPTActor(), OPTCritic(), OPTRM()),
+ lambda: (ChatGLMActor(), None, None),
])
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
batch_size: int,
seq_len: int):
-
actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
@@ -150,29 +140,36 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
}
actor, critic, rm = models_maker()
+ if isinstance(actor, ChatGLMActor):
+ actor = actor.float()
+        tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
+        actor_input = {
+            "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len // 2)), chatglm_special_token,
+                                     torch.randint(0, 100, (batch_size, seq_len // 2 - 2))), dim=1),
+ "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
+ }
assert isinstance(actor, Actor)
base_actor_model = get_base_model(actor)
- assert isinstance(critic, Critic)
- base_critic_model = get_base_model(critic)
- assert isinstance(rm, RewardModel)
- base_rm_model = get_base_model(rm)
-
actor_output = actor(**actor_input)
- critic_output = critic(**critic_input)
- rm_output = rm(**rm_input)
-
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
- assert critic_output.shape == (batch_size, )
- assert rm_output.shape == (batch_size, )
+
+ if critic:
+ assert isinstance(critic, Critic)
+ base_critic_model = get_base_model(critic)
+ critic_output = critic(**critic_input)
+ assert critic_output.shape == (batch_size, )
+
+ if rm:
+ assert isinstance(rm, RewardModel)
+ base_rm_model = get_base_model(rm)
+ rm_output = rm(**rm_input)
+ assert rm_output.shape == (batch_size, )
-@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
-def test_loss(batch_size: int,
- seq_len: int,
- num_labels: int):
+def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss()
loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels),
@@ -182,54 +179,43 @@ def test_loss(batch_size: int,
loss = PolicyLoss()
loss_input = {
- "log_probs": torch.randn(batch_size, ),
- "old_log_probs": torch.randn(batch_size, ),
- "advantages": torch.randn(batch_size, )
+ "log_probs": torch.randn(batch_size,),
+ "old_log_probs": torch.randn(batch_size,),
+ "advantages": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = ValueLoss()
loss_input = {
- "values": torch.randn(batch_size, ),
- "old_values": torch.randn(batch_size, ),
- "reward": torch.randn(batch_size, )
+ "values": torch.randn(batch_size,),
+ "old_values": torch.randn(batch_size,),
+ "reward": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = LogSigLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size, ),
- "reject_reward": torch.randn(batch_size, ),
+ "chosen_reward": torch.randn(batch_size,),
+ "reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
loss = LogExpLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size, ),
- "reject_reward": torch.randn(batch_size, ),
+ "chosen_reward": torch.randn(batch_size,),
+ "reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
if __name__ == "__main__":
- generate_kwargs = dict(max_length=40,
- use_cache=True,
- do_sample=True,
- temperature=1.0,
- top_k=50)
- test_generation(lambda: LlamaActor(),
- batch_size=4,
- seq_len=32,
- generate_kwargs=generate_kwargs)
+ generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
+ test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
test_utils()
test_lora(lora_rank=2, num_dim=8, num_layers=2)
- test_models(models_maker=lambda: (BLOOMActor(),
- BLOOMCritic(),
- BLOOMRM()),
- batch_size=8,
- seq_len=128)
+ test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
- test_loss(batch_size=8, seq_len=128, num_labels=100)
+ test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py
index d0c328e134ff..5b9f74b132f3 100644
--- a/colossalai/auto_parallel/offload/base_offload_module.py
+++ b/colossalai/auto_parallel/offload/base_offload_module.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
-from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from .region_manager import RegionManager
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
index 8e06cec4f463..730a90d74cf8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
@@ -1,5 +1,4 @@
class Registry:
- # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
def __init__(self, name):
self.name = name
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index adb8f62a5084..fb9dae7c9650 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -1,6 +1,6 @@
import warnings
from contextlib import contextmanager
-from typing import Any, Callable, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import torch
import torch.nn as nn
@@ -24,32 +24,36 @@ class Booster:
Booster is a high-level API for training neural networks. It provides a unified interface for
training with different precision, accelerator, and plugin.
- Examples:
- ```python
- colossalai.launch(...)
- plugin = GeminiPlugin(...)
- booster = Booster(precision='fp16', plugin=plugin)
-
- model = GPT2()
- optimizer = HybridAdam(model.parameters())
- dataloader = Dataloader(Dataset)
- lr_scheduler = LinearWarmupScheduler()
- criterion = GPTLMLoss()
-
- model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
-
- for epoch in range(max_epochs):
- for input_ids, attention_mask in dataloader:
- outputs = model(input_ids, attention_mask)
- loss = criterion(outputs.logits, input_ids)
- booster.backward(loss, optimizer)
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
- ```
+
+ ```python
+ # Following is pseudocode
+
+ colossalai.launch(...)
+ plugin = GeminiPlugin(...)
+ booster = Booster(precision='fp16', plugin=plugin)
+
+ model = GPT2()
+ optimizer = HybridAdam(model.parameters())
+ dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ lr_scheduler = LinearWarmupScheduler()
+ criterion = GPTLMLoss()
+
+ model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)
+
+ for epoch in range(max_epochs):
+ for input_ids, attention_mask in dataloader:
+ outputs = model(input_ids.cuda(), attention_mask.cuda())
+ loss = criterion(outputs.logits, input_ids)
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ ```
Args:
- device (str or torch.device): The device to run the training. Default: 'cuda'.
+ device (str or torch.device): The device to run the training. Default: None.
+                       If no plugin is used, or the plugin does not control the device,
+                       this argument selects the training device ('cuda' is used when it is None).
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
@@ -57,8 +61,8 @@ class Booster:
"""
def __init__(self,
- device: str = 'cuda',
- mixed_precision: Union[MixedPrecision, str] = None,
+ device: Optional[str] = None,
+ mixed_precision: Optional[Union[MixedPrecision, str]] = None,
plugin: Optional[Plugin] = None) -> None:
if plugin is not None:
assert isinstance(
@@ -68,13 +72,16 @@ def __init__(self,
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
- warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+ if device is not None:
+ warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
+ device = device or 'cuda'
self.accelerator = Accelerator(device)
# set precision
if self.plugin and self.plugin.control_precision():
- warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+ if mixed_precision is not None:
+ warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
@@ -105,14 +112,19 @@ def boost(
lr_scheduler: Optional[LRScheduler] = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
- Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
+        Wrap the passed-in model, optimizer, criterion, lr_scheduler, and dataloader, injecting the plugin's features into them.
Args:
- model (nn.Module): The model to be boosted.
- optimizer (Optimizer): The optimizer to be boosted.
- criterion (Callable): The criterion to be boosted.
- dataloader (DataLoader): The dataloader to be boosted.
- lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
+            model (nn.Module): The model to be converted into a wrapped model for distributed training.
+                The model might be decorated or partitioned by the plugin's strategy after execution of this method.
+            optimizer (Optimizer, optional): The optimizer to be converted into a wrapped optimizer for distributed training.
+                The optimizer's param groups or states might be decorated or partitioned by the plugin's strategy after execution of this method. Defaults to None.
+ criterion (Callable, optional): The function that calculates loss. Defaults to None.
+ dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
+ lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.
+
+ Returns:
+ List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-dataloader case
@@ -133,10 +145,10 @@ def boost(
return model, optimizer, criterion, dataloader, lr_scheduler
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
- """Backward pass.
+ """Execution of backward during training step.
Args:
- loss (torch.Tensor): The loss to be backpropagated.
+ loss (torch.Tensor): The loss for backpropagation.
optimizer (Optimizer): The optimizer to be updated.
"""
# TODO(frank lee): implement this method with plugin
@@ -146,11 +158,33 @@ def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Optimizer,
+ optimizer: Optional[Optimizer] = None,
return_loss: bool = True,
- return_outputs: bool = False) -> dict:
- # run pipeline forward backward pass
- # return loss or outputs if needed
+ return_outputs: bool = False) -> Dict[str, Any]:
+ """
+ Execute forward & backward when utilizing pipeline parallel.
+ Return loss or Huggingface style model outputs if needed.
+
+ Warning: This function is tailored for the scenario of pipeline parallel.
+        As a result, please don't run the forward/backward pass in the conventional way (model(input)/loss.backward())
+        when doing pipeline parallel training with the booster, as that will cause unexpected errors.
+
+ Args:
+            data_iter (Iterator): The iterator that yields the next batch of data. There are usually two ways to obtain this argument:
+                                 1. wrap the dataloader into an iterator: iter(dataloader)
+                                 2. fetch the next batch from the dataloader and wrap it into an iterator: iter([batch])
+ model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline.
+            criterion (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments, model outputs and inputs, and return a loss tensor.
+ 'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
+ optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
+ return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
+ return_outputs (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
+
+ Returns:
+ Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
+ ret_dict['loss'] is the loss of the forward pass if return_loss is set to True, else None.
+ ret_dict['outputs'] is the Huggingface style model outputs of the forward pass if return_outputs is set to True, else None.
+ """
assert isinstance(self.plugin,
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
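To make the pipeline-specific calling convention concrete, here is a sketch of one training step driven by `execute_pipeline`. It assumes a booster built with a pipeline-capable plugin (e.g. `HybridParallelPlugin` with `pp_size > 1`) and a model whose outputs follow the Huggingface convention of exposing `.loss`; both assumptions are illustrative.

```python
# One pipeline-parallel training step (sketch; `booster`, `model`, `optimizer`
# and `dataloader` are assumed to have been boosted already).
model.train()
for batch in dataloader:
    batch_iter = iter([batch])    # wrap the batch into an iterator, as documented above
    outputs = booster.execute_pipeline(batch_iter,
                                       model,
                                       criterion=lambda y, x: y.loss,   # assumes HF-style outputs
                                       optimizer=optimizer,
                                       return_loss=True)
    optimizer.step()
    optimizer.zero_grad()
    if outputs['loss'] is not None:   # the loss is typically only available on the last stage
        print(outputs['loss'].item())
```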
@@ -170,7 +204,7 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model, optimizer)
- def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
+ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.
Args:
@@ -190,7 +224,7 @@ def save_model(self,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
- use_safetensors: bool = False):
+ use_safetensors: bool = False) -> None:
"""Save model to checkpoint.
Args:
@@ -198,7 +232,7 @@ def save_model(self,
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save the checkpoint in a sharded way.
- If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+ If true, the checkpoint will be a folder with the same format as a Huggingface Transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
@@ -213,7 +247,7 @@ def save_model(self,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors)
- def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+ def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint.
Args:
@@ -232,7 +266,7 @@ def save_optimizer(self,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
- size_per_shard: int = 1024):
+ size_per_shard: int = 1024) -> None:
"""
Save optimizer to checkpoint.
@@ -249,7 +283,7 @@ def save_optimizer(self,
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
- def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.
Args:
@@ -258,7 +292,7 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
- def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Load lr scheduler from checkpoint.
Args:
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 54d815ce701e..de03ba27bfda 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -15,6 +15,7 @@
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
+ save_config_file,
save_state_dict,
save_state_dict_shards,
)
@@ -107,6 +108,7 @@ def save_sharded_model(self,
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
+ save_config_file(model.module, checkpoint_path)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 28a19af0ce91..d15245523226 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,24 +1,28 @@
import random
from contextlib import nullcontext
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
-from torch.nn import Module
+from torch.nn import Module, SyncBatchNorm
+from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
@@ -26,26 +30,54 @@
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
+ if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
+ return x.to(dtype)
+ return x
+
+
class HybridParallelModule(ModelWrapper):
- def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
+ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
+ ddp_config: dict, custom_policy: Policy) -> None:
+
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
+
shardformer = ShardFormer(shard_config)
- module, self.shared_params = shardformer.optimize(module)
- # TODO(ver217): add input type cast
+ if custom_policy is not None:
+ assert isinstance(custom_policy, Policy), 'custom_policy must be an instance of Policy'
+ module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
+
+ # setting process groups for shared parameters
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
+
+ # setting mixed_precision
+ self.mixed_precision = None
if precision == 'fp16':
- module = module.half().cuda()
+ self.mixed_precision = torch.float16
elif precision == 'bf16':
- module = module.to(dtype=torch.bfloat16).cuda()
- else:
- module = module.cuda() # train without AMP
- # TODO(ver217): support TP+DP
+ self.mixed_precision = torch.bfloat16
+ if self.mixed_precision is not None:
+ module = module.to(self.mixed_precision)
+ module = module.cuda()
+
+ # setting input type cast when using mixed precision
+ self.convert_fn = None
+ if self.mixed_precision is not None:
+ self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
+
+ # setting ddp configs
+ if use_ddp:
+ # convert model to sync bn
+ module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
+ # wrap the model with PyTorch DDP
+ module = DDP(module, process_group=dp_group, **ddp_config)
+
super().__init__(module)
def sync_shared_params(self):
@@ -68,19 +100,62 @@ def sync_grads(self):
dist.all_reduce(p.grad, group=self.dp_group)
p.grad.div_(self.dp_group.size())
+ def forward(self, *args, **kwargs):
+ if self.convert_fn is not None:
+ args = tree_map(self.convert_fn, args)
+ kwargs = tree_map(self.convert_fn, kwargs)
+ return super().forward(*args, **kwargs)
+
+ def unwrap(self):
+ module = super().unwrap()
+ if isinstance(module, DDP):
+ module = module.module
+ return module
+
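The `forward` override above relies on `torch.utils._pytree.tree_map` to cast every floating-point tensor nested inside `args`/`kwargs` to the mixed-precision dtype, while leaving integer tensors (e.g. token ids, masks) untouched. A standalone sketch of that casting behavior, reusing the same helper defined earlier in this file:

```python
from functools import partial

import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # Same helper as above: cast only floating-point tensors.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


convert_fn = partial(_convert_floating_point, dtype=torch.float16)
args = ({'hidden': torch.randn(2, 8), 'token_ids': torch.ones(2, 8, dtype=torch.long)},)
args = tree_map(convert_fn, args)
assert args[0]['hidden'].dtype == torch.float16    # floats are cast
assert args[0]['token_ids'].dtype == torch.long    # integer tensors pass through
```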
+
+def get_param_info(optim: Optimizer):
+ # Get a backup of necessary information of parameters for future use, which includes:
+ # 1. A complete param_group, with params in the form of param_id
+ # 2. A mapping from param address (obtained using id(param)) to integer param_id
+ # 3. A mapping from integer param_id to param address.
+ # 4. A mapping from param address (obtained using id(param)) to the original shape of the parameter before sharding.
+ # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
+
+ if optim is None:
+ return {}
+ param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
+ start_index = 0
+ for group in optim.param_groups:
+
+ packed_group = {k: v for k, v in group.items() if k != 'params'}
+ packed_group['params'] = []
+
+ for param_id, param in enumerate(group['params'], start_index):
+ original_shape = param.shape if isinstance(param, torch.Tensor) else None
+ packed_group['params'].append(param_id)
+ param_info['param2id'][id(param)] = param_id
+ param_info['id2param'][param_id] = id(param)
+ param_info['param2shape'][id(param)] = original_shape
+
+ param_info['param_groups'].append(packed_group)
+ start_index += len(group['params'])
+
+ return param_info
+
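As a quick illustration of the backup structure `get_param_info` builds, consider a toy optimizer with two parameters (a hypothetical standalone example, assuming the function above is importable):

```python
import torch
from torch.optim import SGD

w = torch.nn.Parameter(torch.randn(4, 2))
b = torch.nn.Parameter(torch.randn(4))
optim = SGD([w, b], lr=0.1)

info = get_param_info(optim)
assert info['param2id'][id(w)] == 0 and info['param2id'][id(b)] == 1
assert info['id2param'][0] == id(w)
assert info['param2shape'][id(w)] == torch.Size([4, 2])
# Each packed group keeps hyperparameters but stores integer ids instead of tensors.
assert info['param_groups'][0]['params'] == [0, 1]
```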
def init_pipeline_optimizer(optim: Optimizer, model: Module):
- params = set(model.parameters())
+ model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
- params = [p for p in group['params'] if p in params]
+ params = [p for p in group['params'] if p in model_params]
new_param_groups.append({**group, 'params': params})
optim.__setstate__({'param_groups': new_param_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper):
- def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
+ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+ self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim)
@@ -92,6 +167,7 @@ def __init__(self,
optim: Optimizer,
model: Module,
use_pipeline: bool,
+ param_info: OrderedDict,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
@@ -101,6 +177,7 @@ def __init__(self,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0):
+ self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@@ -114,6 +191,7 @@ def __init__(
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
+ param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
@@ -131,6 +209,7 @@ def __init__(
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
+ self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@@ -140,34 +219,104 @@ def __init__(
class HybridParallelPlugin(PipelinePluginBase):
+ """
+ Plugin for Hybrid Parallel Training.
+ Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
+ The sizes of tp and pp should be passed in by the user; the size of dp is then automatically calculated as dp_size = world_size / (tp_size * pp_size).
+
+ Example:
+ >>> from colossalai.booster import Booster
+ >>> from colossalai.booster.plugin import HybridParallelPlugin
+
+ >>> model, train_dataset, optimizer, criterion = ...
+ >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+
+ >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ >>> booster = Booster(plugin=plugin)
+ >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+
+ Args:
+ tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+ pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+ precision (str, optional): Specifies the precision of parameters during training.
+ Automatic mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
+ Defaults to 'fp16'.
+ zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+ When set to 0, ZeRO will not be used. Defaults to 0.
+ enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+ Currently all the optimization methods include fused normalization, flash attention and JIT.
+ Defaults to False.
+ enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+ enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+ enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+ enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+ num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+ microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+ Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+ If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+ initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+ min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+ growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+ backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+ growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+ hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+ max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+ max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+ broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training when using DDP. Defaults to True.
+ ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+ find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+ check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+ gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+ static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+ zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+ cpu_offload (bool, optional): Whether to enable cpu_offload when using ZeRO. Defaults to False.
+ communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+ overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+ custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+ """
+
+ def __init__(self,
+ tp_size: int,
+ pp_size: int,
+ precision: str = 'fp16',
+ zero_stage: int = 0,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_parallelism: bool = False,
+ enable_sequence_overlap: bool = False,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ broadcast_buffers: bool = True,
+ ddp_bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ zero_bucket_size_in_m: int = 12,
+ cpu_offload: bool = False,
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ custom_policy: Policy = None) -> None:
- def __init__(
- self,
- tp_size: int,
- pp_size: int,
- precision: str = 'fp16',
- zero_stage: int = 0,
- cpu_offload: bool = False,
- enable_all_optimization: bool = False,
- enable_fused_normalization: bool = False,
- enable_flash_attention: bool = False,
- enable_jit_fused: bool = False,
- num_microbatches: Optional[int] = None,
- initial_scale: float = 2**16,
- min_scale: float = 1,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- max_scale: float = 2**32,
- max_norm: float = 0,
- ) -> None:
super().__init__()
assert dist.get_world_size() % (
tp_size * pp_size
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
- # TODO(ver217): support zero
- assert zero_stage == 0, 'zero is not support yet'
+
+ if enable_sequence_parallelism:
+ assert tp_size > 1, 'Tensor parallelism must be enabled (tp_size > 1) when using sequence parallelism'
+
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -178,24 +327,31 @@ def __init__(
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
+ self.enable_sequence_parallelism = enable_sequence_parallelism
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
+ self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
- assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
+ assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
- self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
+ self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
+ num_microbatches=num_microbatches,
+ microbatch_size=microbatch_size)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+ self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
- enable_jit_fused=self.enable_jit_fused)
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism,
+ enable_sequence_overlap=enable_sequence_overlap)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@@ -205,6 +361,20 @@ def __init__(
min_scale=min_scale,
max_scale=max_scale,
)
+
+ self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=ddp_bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph)
+
+ self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(self.zero_stage == 2))
+
self.max_norm = max_norm
@property
@@ -237,47 +407,59 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+ param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
- model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
+ use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+ model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
+ self.ddp_config, self.custom_policy)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
optimizer = HybridParallelAMPOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
+ self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
+ optimizer.master_to_working_map)
else:
optimizer = HybridParallelNaiveOptimizer(optimizer,
model,
- use_pipeline=self.enable_pipeline_parallelism)
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info)
else:
+ assert self.dp_size > 1, "ZeRO can only be used when data parallel size is greater than 1."
+ assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
- partition_grad=(self.zero_stage == 2),
- cpu_offload=self.cpu_offload,
+ param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
verbose=True,
clip_grad_norm=self.max_norm,
+ **self.zero_config,
**self.amp_config)
+ self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
+ optimizer._param_store.master_to_working_param)
+
return model, optimizer, criterion, dataloader, lr_scheduler
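The branching above can be summarized as a small dispatch table. The sketch below paraphrases it (it returns wrapper names as strings purely for illustration; it is not the actual implementation):

```python
def pick_wrappers(dp_size: int, pp_size: int, zero_stage: int, precision: str):
    """Simplified paraphrase of HybridParallelPlugin.configure's wrapping decision."""
    use_ddp = dp_size > 1 and pp_size == 1 and zero_stage == 0    # pure DP -> PyTorch DDP
    if zero_stage == 0:
        optim_wrapper = ('HybridParallelAMPOptimizer'             # loss scaling + master params
                         if precision in ('fp16', 'bf16') else 'HybridParallelNaiveOptimizer')
    else:
        # ZeRO stage 1/2 requires dp_size > 1 and fp16/bf16, as asserted above.
        optim_wrapper = 'HybridParallelZeroOptimizer'
    return use_ddp, optim_wrapper
```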
def execute_pipeline(self,
data_iter: Iterator,
model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
- HybridParallelZeroOptimizer],
+ optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
+ HybridParallelZeroOptimizer]] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
# return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
- outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
+ outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
return_outputs)
model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
@@ -339,7 +521,8 @@ def seed_worker(worker_id):
**_kwargs)
def get_checkpoint_io(self) -> CheckpointIO:
- return None
+ self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ return self.checkpoint_io
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 616b218b2070..9adb4beec9b9 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -3,6 +3,7 @@
import warnings
from functools import partial
from pathlib import Path
+from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch
@@ -17,12 +18,17 @@
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
get_shard_filename,
+ load_param_groups_into_optimizer,
+ load_shard_state_dict,
+ load_states_into_optimizer,
save_param_groups,
save_state_dict,
+ sharded_optimizer_loading_epilogue,
+ unwrap_optimizer,
)
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -39,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+ def __init__(self, module: nn.Module, precision: str) -> None:
+ super().__init__(module)
+ self.dtype = None
+ if precision == 'fp16':
+ self.dtype = torch.float16
+ elif precision == 'bf16':
+ self.dtype = torch.bfloat16
+ if self.dtype is not None:
+ module = module.to(self.dtype)
+ module = module.to(get_current_device())
+ self.module = module
+ self.convert_fn = None
+ if self.dtype is not None:
+ self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+ def forward(self, *args, **kwargs):
+ if self.convert_fn is not None:
+ args = tree_map(self.convert_fn, args)
+ kwargs = tree_map(self.convert_fn, kwargs)
+ return super().forward(*args, **kwargs)
+
+ def unwrap(self):
+ # TODO(ver217): this is a workaround for loading model
+ return self
+
+
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -126,44 +160,70 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
- super().load_sharded_optimizer(optimizer, index_file_path, prefix)
- current_rank_state_dict = optimizer.optim.state_dict()['state']
- for param_idx, state in current_rank_state_dict.items():
- for k, v in state.items():
- if isinstance(v, torch.Tensor) and k != 'step':
- padding_size = (self.coordinator.world_size -
- v.numel() % self.coordinator.world_size) % self.coordinator.world_size
- with torch.no_grad():
- v = v.flatten()
- if padding_size > 0:
- v = torch.nn.functional.pad(v, [0, padding_size])
- v_list = v.split(v.numel() // self.coordinator.world_size)
- current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
-
-
-class LowLevelZeroModel(ModelWrapper):
-
- def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
- super().__init__(module)
- self.dtype = None
- if precision == 'fp16':
- self.dtype = torch.float16
- elif precision == 'bf16':
- self.dtype = torch.bfloat16
- module = zero_model_wrapper(module, zero_stage=stage)
- if self.dtype is not None:
- module = module.to(self.dtype)
- module = module.to(get_current_device())
- self.module = module
- self.convert_fn = None
- if self.dtype is not None:
- self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-
- def forward(self, *args, **kwargs):
- if self.convert_fn is not None:
- args = tree_map(self.convert_fn, args)
- kwargs = tree_map(self.convert_fn, kwargs)
- return super().forward(*args, **kwargs)
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = unwrap_optimizer(optimizer)
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+ Lacking param group file under current directory.')
+ id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ for shard_file in checkpoint_files:
+ state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ # shard state dict
+ for param_idx, state in state_dict.items():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+ padding_size = (self.coordinator.world_size -
+ v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ v_list = v.split(v.numel() // self.coordinator.world_size)
+ state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+ load_states_into_optimizer(optimizer, state_dict, id_map)
+
+ sharded_optimizer_loading_epilogue(optimizer)
+
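The pad-and-split arithmetic above gives every rank an equal slice of each flattened optimizer state. A standalone numerical check of that logic (assuming world_size = 4 and rank = 1):

```python
import torch

world_size, rank = 4, 1
v = torch.arange(10, dtype=torch.float32)                            # 10 elements, not divisible by 4
padding_size = (world_size - v.numel() % world_size) % world_size    # -> 2
v = torch.nn.functional.pad(v.flatten(), [0, padding_size])          # padded to 12 elements
v_list = v.split(v.numel() // world_size)                            # four chunks of 3
assert v_list[rank].tolist() == [3.0, 4.0, 5.0]                      # this rank keeps only its slice
```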
+ def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+ use_safetensors: bool):
+ assert isinstance(model, LowLevelZeroModel)
+ super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
+
+ def save_sharded_model(self,
+ model: nn.Module,
+ checkpoint_path: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False):
+ assert isinstance(model, LowLevelZeroModel)
+ super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+ use_safetensors)
+
+ def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+ assert isinstance(model, LowLevelZeroModel)
+ super().load_unsharded_model(model.module, checkpoint, strict)
+ model.update_master_params()
+
+ def load_sharded_model(self,
+ model: LowLevelZeroModel,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True):
+ assert isinstance(model, LowLevelZeroModel)
+ super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+ model.update_master_params()
class LowLevelZeroPlugin(DPPluginBase):
@@ -223,22 +283,24 @@ def __init__(
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
-
+ assert norm_type == 2.0, 'LowLevelZeroPlugin only supports norm_type=2.0 now'
self.stage = stage
self.precision = precision
- self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
- communication_dtype=communication_dtype,
- overlap_communication=overlap_communication,
- cpu_offload=cpu_offload)
- self.optim_kwargs = dict(initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- max_norm=max_norm,
- norm_type=norm_type)
+ self.zero_optim_kwargs = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ clip_grad_norm=max_norm,
+ reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(stage == 2),
+ )
self.verbose = verbose
# set class name with stage, for better error message
@@ -269,15 +331,15 @@ def configure(
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
- model = LowLevelZeroModel(model, self.stage, self.precision)
+ model = LowLevelZeroModel(model, self.precision)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
- optimizer = zero_optim_wrapper(model.unwrap(),
- optimizer,
- optim_config=self.zero_optim_config,
- **self.optim_kwargs,
- verbose=self.verbose)
+ optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+ **self.zero_optim_kwargs,
+ verbose=self.verbose)
+ # inject update_master_params
+ model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
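A minimal usage sketch of the plugin after these changes, assuming the distributed environment is already initialized and that `model`, `optimizer`, `criterion`, and `dataloader` are defined elsewhere; the checkpoint path is a placeholder. Loading a model checkpoint now also refreshes the fp32 master params through the injected `update_master_params` hook:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

# ... train, save, restart ...
booster.load_model(model, 'ckpt/model')   # placeholder path; also re-syncs master params
```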
diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py
index 67ade9330f5b..f52844db082f 100644
--- a/colossalai/booster/plugin/pp_plugin_base.py
+++ b/colossalai/booster/plugin/pp_plugin_base.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Iterator, Optional
import torch
@@ -15,7 +15,7 @@ def execute_pipeline(self,
data_iter: Iterator,
model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: OptimizerWrapper,
+ optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
pass
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index c25048e25754..e1aa6543ef39 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -1,5 +1,6 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 83e4bdcc863b..faaf1d22722a 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -23,6 +23,7 @@
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
+ save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
@@ -78,8 +79,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
- del state_dict
- gc.collect()
sharded_optimizer_loading_epilogue(optimizer)
@@ -185,6 +184,7 @@ def save_sharded_model(self,
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
new file mode 100644
index 000000000000..270fd8564754
--- /dev/null
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -0,0 +1,704 @@
+import copy
+import gc
+import logging
+import os
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import OptimizerWrapper
+
+from .general_checkpoint_io import GeneralCheckpointIO
+from .index_file import CheckpointIndexFile
+from .utils import (
+ StateDictSharder,
+ gather_distributed_param,
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ is_safetensors_available,
+ load_shard_state_dict,
+ load_state_dict_into_model,
+ load_states_into_optimizer,
+ save_config_file,
+ save_param_groups,
+ save_state_dict_shards,
+ search_tp_partition_dim,
+ sharded_optimizer_loading_epilogue,
+)
+
+try:
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
+except ImportError:
+ _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+
+
+class HybridParallelCheckpointIO(GeneralCheckpointIO):
+ """
+ CheckpointIO for Hybrid Parallel Training.
+
+ Args:
+ dp_group (ProcessGroup): Process group along data parallel dimension.
+ pp_group (ProcessGroup): Process group along pipeline parallel dimension.
+ tp_group (ProcessGroup): Process group along tensor parallel dimension.
+ zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
+ verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
+ """
+
+ def __init__(self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ verbose: bool = True) -> None:
+ super().__init__()
+ self.dp_group = dp_group
+ self.pp_group = pp_group
+ self.tp_group = tp_group
+ self.dp_rank = dist.get_rank(self.dp_group)
+ self.tp_rank = dist.get_rank(self.tp_group)
+ self.pp_rank = dist.get_rank(self.pp_group)
+ self.dp_size = dist.get_world_size(dp_group)
+ self.pp_size = dist.get_world_size(pp_group)
+ self.tp_size = dist.get_world_size(tp_group)
+ self.use_zero = (zero_stage > 0)
+ self.verbose = verbose
+ self.working_to_master_map = None
+ self.master_to_working_map = None
+ self.coordinator = DistCoordinator()
+
+ @staticmethod
+ def _model_sharder(model: nn.Module,
+ prefix: str = '',
+ keep_vars: bool = False,
+ size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+ # An internal method that breaks the model's state_dict into shards of limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+
+ # Save parameters.
+ for name, param in model.named_parameters():
+ if param is None:
+ continue
+ # Gather tensor pieces when using tensor parallel.
+ param_ = gather_distributed_param(param, keep_vars=False)
+ block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+ if block is not None:
+ yield block, block_size
+
+ # Save buffers.
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in model._non_persistent_buffers_set:
+ buffer = buf if keep_vars else buf.detach()
+ block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
+ if block is not None:
+ yield block, block_size
+
+ # Save extra states.
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if getattr(model.__class__, "get_extra_state",
+ torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ extra_state = model.get_extra_state()
+ block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
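For intuition, the generator above can be consumed directly; each yielded block is a partial state_dict capped near `size_per_shard`. A hypothetical consumption loop (the real saving is done by `save_state_dict_shards`):

```python
# Hypothetical direct consumption of the sharder (for illustration only).
for block, block_size in HybridParallelCheckpointIO._model_sharder(model, size_per_shard=1024):
    if not block:          # the final yield may be empty if everything was already flushed
        continue
    print(f'shard with {len(block)} entries, accumulated size {block_size}')
```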
+ @staticmethod
+ def _optimizer_sharder(optimizer: OptimizerWrapper,
+ use_zero: bool,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
+ size_per_shard: int = 1024):
+
+ # An internal method that breaks the optimizer's state_dict into shards of limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+ param_info = optimizer.param_info
+
+ for param, state in optimizer.optim.state.items():
+
+ if param is None:
+ continue
+
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ param_id = param_info['param2id'][id(working_param)]
+ original_shape = param_info['param2shape'][id(working_param)]
+ state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=dp_group,
+ tp_group=tp_group,
+ use_zero=use_zero,
+ inplace=False)
+
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_model(self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False) -> None:
+ """
+ Save sharded model checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+ - Multiple files that store state tensors of models.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin".
+ If pipeline parallelism is not used, "pytorch_model.-000XX.bin"
+
+
+ Args:
+ model (nn.Module): Model on local device to be saved.
+ checkpoint (str): Checkpointing path which should be a directory path.
+ gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+ prefix (str, optional): Prefix of the files to save. Defaults to None.
+ size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Devices along the same dp_group share the same copy of the model.
+ # So only let the device with dp_rank == 0 save the model.
+ if self.dp_rank != 0:
+ return
+
+ # Then collect the sharded parameters & buffers along tp_group.
+ # Only devices with tp_rank == 0 are responsible for model saving.
+ state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = (self.tp_rank == 0)
+
+ if self.pp_size == 1:
+ # When pipeline is not used, save the model shards as in general checkpointIO
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors)
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint)
+ if self.verbose:
+ logging.info(f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}.")
+
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+ weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ use_pp_format=True)
+ if control_saving:
+ assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ return
+
+ dist.barrier(self.pp_group)
+
+ # The global master rank integrates the index files and cleans up the folder.
+ if self.pp_rank == 0:
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for weight, weight_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(weight, weight_filename)
+
+ final_index_file.write_index_file(final_index_file_path)
+ save_config_file(model, checkpoint)
+ rmtree(tmp_index_file_folder)
+ if self.verbose:
+ logging.info(f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}.")
+
+ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+ """
+ Load sharded model with the given path to index file of checkpoint folder.
+
+ Args:
+ model (nn.Module): The model to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+ This argument should be manually set to False since params on the same device might be stored in different files.
+ """
+
+ # Check whether the checkpoint uses safetensors.
+ use_safetensors = False
+ if "safetensors" in checkpoint_index_file.name:
+ use_safetensors = True
+
+ if use_safetensors and not is_safetensors_available():
+ raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ strict = False
+
+ # Load params & buffers to model.
+ # Keep a record of loaded files so that file will not be repeatedly loaded.
+ loaded_file = set()
+
+ def _load(name: str):
+ if name not in weight_map:
+ raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+ filename = weight_map[name]
+
+ # If this param/buffer has been loaded before, directly return.
+ if filename in loaded_file:
+ return
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+ missing_keys = []
+
+ load_state_dict_into_model(model,
+ state_dict,
+ missing_keys=missing_keys,
+ strict=strict,
+ load_sub_module=True)
+ loaded_file.add(filename)
+
+ # Load parameters.
+ for name, _ in model.named_parameters():
+ _load(name)
+
+ # Load buffers.
+ non_persistent_buffers = set()
+ for n, m in model.named_modules():
+ non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in non_persistent_buffers:
+ _load(name)
+
+ # Load extra states.
+ extra_state_key = _EXTRA_STATE_KEY_SUFFIX
+ if getattr(model.__class__, "get_extra_state",
+ torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ _load(extra_state_key)
+
+ # Update master params if mixed-precision training is enabled.
+ with torch.no_grad():
+ if self.working_to_master_map is not None:
+ for param in model.parameters():
+ if (param is None) or (id(param) not in self.working_to_master_map):
+ continue
+ master_param = self.working_to_master_map[id(param)]
+ if self.use_zero:
+ # master_param is sharded under Zero setting
+ padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
+ if padding_size > 0:
+ padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+ else:
+ padded_param = param.data.view(-1)
+ sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
+ master_param.data.copy_(sharded_param.data)
+ else:
+ master_param.data.copy_(param.data)
+
+ if self.verbose:
+ logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def save_sharded_optimizer(self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024):
+ """
+ Save sharded optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+ - Multiple files that store state tensors of optimizers.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin".
+ If pipeline parallelism is not used, "pytorch_optim.-000XX.bin"
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+ prefix (str): Prefix of the files to save
+ size_per_shard (int): Max file size of each file shard that store state tensors
+ """
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Devices along the same dp_group share the same copies of optimizer states when zero is not used.
+ # In this case only let the device with dp_rank == 0 save the states.
+ if not self.use_zero and self.dp_rank != 0:
+ return
+
+ # Then collect the sharded states along dp_group(if using zero)/tp_group.
+ # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+ state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
+ optimizer,
+ use_zero=self.use_zero,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ master_to_working_map=self.master_to_working_map,
+ size_per_shard=size_per_shard)
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
+
+ if self.pp_size == 1:
+ # When pipeline is not used, save the optimizer shards as in general checkpointIO
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving)
+
+ if control_saving:
+ # Store param groups.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+ # Store index file.
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ if self.verbose:
+ logging.info(f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}.")
+
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True)
+
+ if control_saving:
+ assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ return
+
+ dist.barrier(self.pp_group)
+
+ # The global master rank integrates the index files and cleans up the folder.
+ if self.pp_rank == 0:
+
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for param_id, state_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(param_id, state_filename)
+
+ # Store param groups.
+ final_index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+
+ final_index_file.write_index_file(final_index_file_path)
+ rmtree(tmp_index_file_folder)
+
+ if self.verbose:
+ logging.info(f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}.")
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+ """
+ Load sharded optimizer with the given path to index file of checkpoint folder.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ prefix (str): Not used.
+ """
+
+ def _get_param_id_from_optimizer_param(param: torch.Tensor,
+ master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ return optimizer.param_info['param2id'][id(working_param)]
+
+ # id_map is a mapping from the param ids kept by the current pipeline stage to their corresponding parameter objects.
+ # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+ # IDs are obtained through the param2id mapping saved earlier in optimizer.param_info.
+ id_map = {}
+ for pg in optimizer.optim.param_groups:
+ for param in pg['params']:
+ param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+ id_map[param_id] = param
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory.')
+ saved_groups = torch.load(param_group_path)
+
+ updated_groups = []
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ # obtain updated param group
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg['params'] = old_pg['params'] # The parameters in the same group shouldn't change.
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({'param_groups': updated_groups})
+
+ # Load saved states into optimizer.
+ # Keep a record of loaded files so that a file will not be repeatedly loaded.
+ loaded_file = set()
+ for pg in optimizer.optim.param_groups:
+ for param in pg['params']:
+ if param is None:
+ continue
+ param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+ if param_id not in weight_map:
+ continue
+ filename = weight_map[param_id]
+
+ # If this param's states have been loaded before, skip this file.
+ if filename in loaded_file:
+ continue
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+ load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+ loaded_file.add(filename)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ device = param.device
+ if self.master_to_working_map is not None:
+ working_param = self.master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info['param2shape'][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(state,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device=device,
+ inplace=True)
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ if self.verbose:
+ logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Save lr scheduler to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
+
+ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
+ master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
+ """
+ Create mappings between working params (for forward/backward) and master params (for optimizer update) with the passed-in mappings.
+ This mapping can only be created when mixed precision is used.
+ The created mappings should be mappings from integer parameter addresses to parameter objects.
+
+ Args:
+ working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameter objects/addresses to master parameter objects.
+ master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameter objects/addresses to working parameter objects.
+ """
+ self.working_to_master_map = dict()
+ for k, v in working_to_master_map.items():
+ if isinstance(k, torch.Tensor):
+ self.working_to_master_map[id(k)] = v
+ elif isinstance(k, int):
+ self.working_to_master_map[k] = v
+ else:
+ raise ValueError(
+ f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
+
+ self.master_to_working_map = dict()
+ for k, v in master_to_working_map.items():
+ if isinstance(k, torch.Tensor):
+ self.master_to_working_map[id(k)] = v
+ elif isinstance(k, int):
+ self.master_to_working_map[k] = v
+ else:
+ raise ValueError(
+ f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
+
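A minimal sketch of the key normalization above, assuming `ckpt_io` stands for an already-constructed instance of this checkpoint IO class (the variable name is hypothetical):

```python
import torch

# Hypothetical pair: an fp16 working param and its fp32 master copy.
working = torch.zeros(4, dtype=torch.float16)
master = working.detach().float()

# Keys may be tensors or integer addresses; both are normalized to id -> tensor.
ckpt_io.link_master_and_working_param({working: master}, {id(master): working})
assert ckpt_io.working_to_master_map[id(working)] is master
assert ckpt_io.master_to_working_map[id(master)] is working
```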
+ @staticmethod
+ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
+ dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
+ inplace: bool) -> OrderedDict:
+ """
+ Given a parameter and its optimizer states, gather the complete optimizer state for saving.
+
+ Args:
+ state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
+ param (torch.Tensor): The given parameter. It should be working_param when using Zero.
+ original_shape (torch.Size): The size of parameter before sharding.
+ dp_group (ProcessGroup): The process group of data parallel.
+ tp_group (ProcessGroup): The process group of tensor parallel.
+ use_zero (bool): Whether Zero is used.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+
+ Returns:
+ OrderedDict: The complete optimizer state of given parameter.
+ """
+ dp_size = dist.get_world_size(dp_group)
+ tp_size = dist.get_world_size(tp_group)
+ current_shape = param.shape
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+
+ # First gather Zero shards.
+ if use_zero:
+ v = v.cuda()
+ gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
+ dist.all_gather(gather_tensor, v, group=dp_group)
+ v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
+
+ # Then gather TP shards.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
+ if partition_dim is not None:
+ gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
+ dist.all_gather(gather_tensor, v, group=tp_group)
+ v = torch.cat(gather_tensor, dim=partition_dim)
+
+ state_[k] = v.detach().clone().cpu()
+
+ return state_
+
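The Zero branch above undoes a flat, padded, per-rank layout; a self-contained sketch of that arithmetic, with a hypothetical dp_size of 2 and no real process group:

```python
import torch

# A (2, 3) parameter flattened, padded to a multiple of dp_size=2, and split
# into two flat shards, mimicking Zero's per-rank layout.
param = torch.arange(6.).reshape(2, 3)
flat = torch.nn.functional.pad(param.flatten(), [0, 2])    # 6 -> 8 elements
shard0, shard1 = flat.split(4)

# The gather path: stack the shards, flatten, trim the padding, reshape.
restored = torch.stack([shard0, shard1]).view(-1)[:param.numel()].reshape_as(param)
assert torch.equal(restored, param)
```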
+ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
+ original_shape: torch.Size, device: torch.device,
+ inplace: bool) -> OrderedDict:
+ """
+ With complete optimizer states of a specific parameter loaded from checkpoint,
+ slice out the sharded optimizer states kept by the current device.
+
+ Args:
+ state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
+ current_shape (torch.Size): The size of parameter after sharding.
+ original_shape (torch.Size): The size of parameter before sharding.
+ device (torch.device): The destination device of loaded optimizer states.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+
+ Returns:
+ OrderedDict: The sharded optimizer state of the given parameter.
+ """
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+
+ # Shard state along tensor parallel group.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+ if partition_dim is not None:
+ slice_size = current_shape[partition_dim]
+ v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
+
+ # Shard state along data parallel group when using Zero.
+ if self.use_zero:
+ padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ slice_size = v.numel() // self.dp_size
+ v = v.split(slice_size, dim=0)[self.dp_rank]
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
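And the inverse direction used by this method: slicing a fully gathered state back down to the local shard, assuming a dim-0 tensor-parallel partition with tp_size=2 and tp_rank=1:

```python
import torch

full = torch.arange(8.).reshape(4, 2)     # fully gathered optimizer state
slice_size = 4 // 2                       # current_shape[partition_dim]
local = full.split(slice_size, dim=0)[1]  # keep this rank's slice
assert torch.equal(local, torch.tensor([[4., 5.], [6., 7.]]))
```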
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 8837776aee4d..3441eca38ce7 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,4 +1,5 @@
# coding=utf-8
+import copy
import os
import re
from collections import abc as container_abcs
@@ -11,9 +12,14 @@
import torch.nn as nn
from torch.optim import Optimizer
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor.d_tensor import is_distributed_tensor
+from colossalai.tensor.d_tensor import (
+ is_customized_distributed_tensor,
+ is_distributed_tensor,
+ to_global,
+ to_global_for_customized_distributed_tensor,
+)
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -88,8 +94,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
return False
+def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
+ """
+ Given the current shape of parameter and the shape of parameter before sharding,
+ return the dimension along which the parameter is sharded when using tensor parallel.
+ If tensor parallel is not used, return None.
+
+ Args:
+ current_shape (torch.Size): The current shape of parameter after sharding.
+ original_shape (torch.Size): The shape of parameter before sharding.
+ tp_size (int): The size of tp group.
+
+ Returns:
+ Optional[int]: The dimension along which parameter is partitioned.
+ """
+ partition_dim = None
+ for dim, length in enumerate(original_shape):
+ if length > current_shape[dim]:
+ partition_dim = dim
+ break
+ if partition_dim is not None:
+ assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
+ f"The parameter isn't evenly distributed among tensor parallel group: \
+ shape before sharding {original_shape}, shape after sharding {current_shape}"
+
+ return partition_dim
+
+
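A quick sanity check of the new helper's contract, using the import path this diff adds it under:

```python
import torch
from colossalai.checkpoint_io.utils import search_tp_partition_dim

# An (8, 4) weight split along dim 0 across tp_size=4 leaves (2, 4) per rank.
assert search_tp_partition_dim(torch.Size([2, 4]), torch.Size([8, 4]), tp_size=4) == 0
# An unsharded parameter keeps its shape, so no partition dim is reported.
assert search_tp_partition_dim(torch.Size([8, 4]), torch.Size([8, 4]), tp_size=4) is None
```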
# ======================================
-# Helper functions for saving shard file
+# Helper classes and functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
'''
@@ -104,12 +137,97 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim
+class StateDictSharder:
+
+ def __init__(self, size_per_shard: int) -> None:
+ self.max_shard_size = size_per_shard
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
+
+ tensor_size = calculate_tensor_size(tensor)
+ ret_block = None
+ ret_block_size = 0
+
+ # before we return the current block and create a new block,
+ # we need to ensure that the current block is not empty
+ if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
+ ret_block = self.current_block
+ ret_block_size = self.current_block_size
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ self.current_block[name] = tensor
+ self.current_block_size += tensor_size
+ return ret_block, ret_block_size
+
+ def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
+
+ # A state might contain more than one tensor.
+ # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+ state_size = 0
+ isDTensor = False
+ for state_tensor in state.values():
+
+ # When state_tensor is not of Tensor class,
+ # e.g., an SGD optimizer with momentum set to 0 can have None as state
+ # The calculation of tensor size should be skipped to avoid error.
+ if not isinstance(state_tensor, torch.Tensor):
+ continue
+
+ # If the states are stored as DTensors, mark isDTensor as true.
+ if is_distributed_tensor(state_tensor):
+ isDTensor = True
+ state_size += calculate_tensor_size(state_tensor)
+
+ ret_block = None
+ ret_block_size = 0
+
+ # directly return if state is stored as distributed tensor
+ if isDTensor:
+ return ret_block, ret_block_size
+
+ # before we return the current block and create a new block,
+ # we need to ensure that the current block is not empty
+ if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
+ ret_block = self.current_block
+ ret_block_size = self.current_block_size
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ self.current_block[param_id] = state
+ self.current_block_size += state_size
+ return ret_block, ret_block_size
+
+
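A usage sketch of the sharder; the size cap is expressed in the same units as `calculate_tensor_size` reports (1024 in the defaults used by the shard functions below):

```python
from collections import OrderedDict

import torch

from colossalai.checkpoint_io.utils import StateDictSharder

sharder = StateDictSharder(size_per_shard=1024)
state_dict = OrderedDict((f"w{i}", torch.zeros(8)) for i in range(3))

shards = []
for name, tensor in state_dict.items():
    block, block_size = sharder.append_param(name, tensor)
    if block is not None:    # a full block was closed off
        shards.append((block, block_size))
# The still-open last block must be flushed by the caller.
shards.append((sharder.current_block, sharder.current_block_size))
```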
+def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
+ """
+ Gather the complete parameter for saving if the passed-in param is distributed under the tp setting.
+
+ Args:
+ param (torch.Tensor): A model parameter, might be d_tensor.
+ keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
+
+ Returns:
+ torch.Tensor: the complete parameter
+ """
+ param_ = param if keep_vars else param.detach()
+ if is_distributed_tensor(param_):
+ return to_global(param_)
+ elif is_customized_distributed_tensor(param_):
+ return to_global_for_customized_distributed_tensor(param_)
+ else:
+ return param_
+
+
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
- use_safetensors: bool = False) -> int:
+ use_safetensors: bool = False,
+ use_pp_format: bool = False) -> int:
'''
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
@@ -117,18 +235,21 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
- is_master (bool): Whether current rank is master.
- use_safetensors (bool): Whether to use safetensors to save checkpoint.
+ is_master (bool): Whether current rank is the main process.
+ use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+ use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
'''
total_size = 0
+ shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
+ shard, current_size = shard_pair
if not is_master:
+ del shard
continue
- shard, current_size = shard_pair
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
@@ -137,6 +258,11 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
# Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+ shard_filenames.append(shard_file)
+ del shard
+
+ # Clean folder, delete unneeded files.
+ clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size
@@ -146,28 +272,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
- current_block = {}
- current_block_size = 0
+ state_dict_sharder = StateDictSharder(max_shard_size)
for key, weight in state_dict.items():
- ret_block = None
- ret_block_size = 0
if not is_distributed_tensor(weight):
- weight_size = calculate_tensor_size(weight)
-
- # If this weight is going to tip up over the maximal size, we split.
- if current_block_size + weight_size > max_shard_size and current_block_size > 0:
- ret_block = current_block
- ret_block_size = current_block_size
- current_block = {}
- current_block_size = 0
- current_block[key] = weight
- current_block_size += weight_size
+ block, block_size = state_dict_sharder.append_param(key, weight)
- if ret_block != None:
- yield ret_block, ret_block_size
+ if block is not None:
+ yield block, block_size
- yield current_block, current_block_size
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
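Callers drive the refactored generator like this (a sketch with a toy module; note that the final, still-open block is always yielded, even when it is the only one):

```python
import torch

from colossalai.checkpoint_io.utils import shard_model_checkpoint

model = torch.nn.Linear(4, 4)
shards = list(shard_model_checkpoint(model.state_dict(), max_shard_size=1024))

# Every parameter lands in exactly one shard.
assert sum(len(shard) for shard, _ in shards) == len(model.state_dict())
```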
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
@@ -178,47 +293,212 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state']
-
- current_block = {}
- current_block_size = 0
+ state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state)
+ if block is not None:
+ yield block, block_size
- ret_block = None
- ret_block_size = 0
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
- # A state might contain more than one tensors.
- # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
- state_size = 0
- isDTensor = False
- for state_tensor in state.values():
- # When state_tensor is not of Tensor class,
- # e.g., a SGD optimizer with momentum set to 0 can have None as state
- # The calculation of tensor size should be skipped to avoid error.
- if not isinstance(state_tensor, torch.Tensor):
- continue
+# ======================================
+# Helper functions for saving state dict
+# ======================================
- # If the states are stored as DTensors, mark isDTensor as true.
- if is_distributed_tensor(state_tensor):
- isDTensor = True
- state_size += calculate_tensor_size(state_tensor)
- if not isDTensor:
+def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
+ """
+ Save state dict to checkpoint.
+
+ Args:
+ state_dict (dict): state dict.
+ checkpoint_file_path (str): path to the checkpoint file.
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ """
+ if use_safetensors:
+ assert is_safetensors_available(), "safetensors is not available."
+ assert checkpoint_file_path.endswith('.safetensors'), \
+ "safetensors only supports .safetensors suffix for checkpoint file."
+ from safetensors.torch import save_file as safe_save_file
+ safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
+ else:
+ torch.save(state_dict, checkpoint_file_path)
+
+
+def save_param_groups(state_dict: dict, group_file_path: str) -> None:
+ """
+ Save information of param_groups to given file path.
+
+ Args:
+ state_dict (dict): state dict.
+ group_file_path (str): path to the group file.
+ """
+ param_groups = state_dict["param_groups"]
+ torch.save(param_groups, group_file_path)
+
+
+def clean_folder(checkpoint_path: str,
+ weights_name: str,
+ shard_filenames: List[str],
+ is_master: bool = True,
+ use_pp_format: bool = False):
+ """
+ Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
+
+ Args:
+ checkpoint_path (str): Path to the checkpoint directory.
+ weights_name (str): Decides the prefix of filenames of weight shards.
+ shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
+ is_master (bool, optional): Whether current rank is the main process. Defaults to True.
+ use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+ """
+ if is_master:
+ for filename in os.listdir(checkpoint_path):
+ full_filename = os.path.join(checkpoint_path, filename)
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+ filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+ if not use_pp_format:
+ reg = re.compile(r"(.*?)-\d{5}")
+ else:
+ # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
+ reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
+ if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
+ and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+ os.remove(full_filename)
+
+
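The two filename patterns the cleaner accepts, checked directly; suffixes are stripped before matching, as in the code above:

```python
import re

flat_pattern = re.compile(r"(.*?)-\d{5}")
pp_pattern = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")

# e.g. "pytorch_model-00001.bin" with its suffix stripped:
assert flat_pattern.fullmatch("pytorch_model-00001") is not None
# e.g. a pipeline-format shard written by one stage:
assert pp_pattern.fullmatch("pytorch_model-stage-00001-shard-00002") is not None
```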
+def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
+ """
+ Save config.json/generation_config.json if model is a Huggingface pretrained model.
+ This method can only be called when a model is saved in a sharded way.
+
+ Args:
+ model (nn.Module): The model whose config should be saved if it's a huggingface model.
+ checkpoint_path (str): Path to the checkpoint directory.
+ is_master (bool): Whether current rank is main process.
+ """
+ try:
+ from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
+ from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
+ except ImportError:
+ return
+ if not isinstance(model, PreTrainedModel):
+ return
+
+ model = unwrap_huggingface_model(model)
+
+ # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+ dtype = get_parameter_dtype(model)
+ model.config.torch_dtype = str(dtype).split(".")[1]
+
+ # Attach architecture to the config
+ model.config.architectures = [model.__class__.__name__]
+
+ # Save the config
+ if is_master:
+ model.config.save_pretrained(checkpoint_path)
+ if model.can_generate():
+ model.generation_config.save_pretrained(checkpoint_path)
+
+
+def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
+ """
+ Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
+ only one tensor.
+
+ Args:
+ name (str): name of the distributed tensor.
+ tensor (Tensor): tensor to be saved.
+ index_file (CheckpointIndexFile): the index file object to be updated.
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ """
+ root_path = index_file.root_path
+ output_root_path = root_path.joinpath('dtensor')
+
+ # create directory
+ output_root_path.mkdir(exist_ok=True)
+
+ # save tensor to this directory
+ # TODO(YuliangLiu): get index of the tensor shard
+ # e.g. index =
+ index = 0
+
+ # save tensor to file
+ ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
+ ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
+
+ # dtensor ckpt file always contains only one tensor
+ state_dict = {name: tensor}
+ save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
+
+ # update the weight map
+ # * means all shards
+ ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
+ index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
+
+
+def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
+ """
+ Get checkpoint file suffix.
+
+ Args:
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
+
+ Returns:
+ str: checkpoint file suffix.
+ """
+ if use_safetensors:
+ return '.safetensors'
+ else:
+ return '.bin'
+
+
+def generate_checkpoint_shard_file_name(index: int,
+ total_number: int,
+ use_safetensors: bool,
+ prefix: str = None) -> str:
+ """
+ Generate checkpoint shard file name.
+
+ Args:
+ index (int): index of the shard.
+ total_number (int): total number of shards.
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ prefix (str): prefix of the shard file name. Default: None.
+
+ Returns:
+ str: checkpoint shard file name.
+ """
+ suffix = get_checkpoint_file_suffix(use_safetensors)
+
+ if prefix is None:
+ return f"{index:05d}-of-{total_number:05d}{suffix}"
+ else:
+ return f"{prefix}-{index:05d}-of-{total_number:05d}{suffix}"
+
+
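Assuming the double-dot fix above (the suffix returned by `get_checkpoint_file_suffix` already carries its leading dot), the generated names follow the familiar HuggingFace convention:

```python
from colossalai.checkpoint_io.utils import generate_checkpoint_shard_file_name

assert generate_checkpoint_shard_file_name(1, 4, use_safetensors=False, prefix="pytorch_model") \
    == "pytorch_model-00001-of-00004.bin"
assert generate_checkpoint_shard_file_name(2, 4, use_safetensors=True) == "00002-of-00004.safetensors"
```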
+def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
+ """
+ Generate dtensor file name.
- if current_block_size + state_size > max_shard_size and current_block_size > 0:
- ret_block = current_block
- ret_block_size = current_block_size
- current_block = {}
- current_block_size = 0
+ Args:
+ param_name (str): name of the distributed parameter.
+ index (int): index of the shard.
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
- current_block[param_id] = state
- current_block_size += state_size
+ Returns:
+ str: dtensor file name.
+ """
+ suffix = get_checkpoint_file_suffix(use_safetensors)
+ return f'{param_name}.{index}{suffix}'
- if ret_block != None:
- yield ret_block, ret_block_size
- yield current_block, current_block_size
+# ========================================
+# Helper functions for loading state dict
+# ========================================
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
@@ -237,7 +517,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
return safe_load_file(checkpoint_file)
else:
- return torch.load(checkpoint_file)
+ return torch.load(checkpoint_file, map_location=torch.device('cpu'))
def load_state_dict_into_model(model: nn.Module,
@@ -297,7 +577,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
- saved_groups = torch.load(param_group_path)
+ saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
@@ -331,17 +611,21 @@ def update_group(group, new_group):
return id_map
-def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
r"""Copies states from `state_dict` into an Optimizer object.
Args:
optimizer(Optimizer): An initialized Optimizer object to be loaded
- state_dict(dict): a mapping from tensor index (an integer)
+ state_dict(dict): A mapping from tensor index (an integer)
to its states to be loaded (a mapping from state name to a tensor).
- id_map(dict): a mapping from tensor index (an integer)
+ id_map(dict): A mapping from tensor index (an integer)
to its corresponding parameter (a tensor) whose states will be updated.
+ strict(bool, optional): If set to True, only load states whose ids appear in id_map. Defaults to False.
"""
+ # Ensure that the keys of state_dict are integers.
+ state_dict = {int(k): v for k, v in state_dict.items()}
+
def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
@@ -368,7 +652,7 @@ def cast(param, value, key=None):
if k in id_map:
param = id_map[k]
new_states[param] = cast(param, v)
- else:
+ elif not strict:
new_states[k] = v
optimizer.state.update(new_states)
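A runnable sketch of what the new `strict=True` changes, with a hypothetical two-entry checkpoint in which only param id 0 is mapped on this rank:

```python
import torch
from torch.optim import Adam

from colossalai.checkpoint_io.utils import load_states_into_optimizer

p0 = torch.nn.Parameter(torch.zeros(2))
optimizer = Adam([p0])
state_dict = {
    0: {'step': 1, 'exp_avg': torch.zeros(2), 'exp_avg_sq': torch.zeros(2)},
    7: {'step': 1, 'exp_avg': torch.zeros(4), 'exp_avg_sq': torch.zeros(4)},
}

load_states_into_optimizer(optimizer, state_dict, id_map={0: p0}, strict=True)
# The mapped id is loaded onto its parameter; the unmapped id is dropped.
assert p0 in optimizer.state and 7 not in optimizer.state
```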
@@ -386,165 +670,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
optimizer.defaults.setdefault('differentiable', False)
-# ======================================
-# Helper functions for saving state dict
-# ======================================
-
-
-def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
- """
- Save state dict to checkpoint.
-
- Args:
- state_dict (dict): state dict.
- checkpoint_file_path (str): path to the checkpoint file.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
- """
- if use_safetensors:
- assert is_safetensors_available(), "safetensors is not available."
- assert checkpoint_file_path.endswith('.safetensors'), \
- "safetensors only supports .safetensors suffix for checkpoint file."
- from safetensors.torch import save_file as safe_save_file
- safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
- else:
- torch.save(state_dict, checkpoint_file_path)
-
-
-def save_param_groups(state_dict: dict, group_file_path: str) -> None:
- """
- Save information of param_groups to given file path.
-
- Args:
- state_dict (dict): state dict.
- group_file_path (str): path to the group file.
- """
- param_groups = state_dict["param_groups"]
- torch.save(param_groups, group_file_path)
-
-
-def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
- """
- Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
- only one tensor.
-
- Args:
- tensor (Tensor): tensor to be saved.
- index_file (CheckpointIndexFile): path to the checkpoint file.
- size_per_shard (int): size per shard in MB.
- """
- root_path = index_file.root_path
- output_root_path = root_path.joinpath('dtensor')
-
- # create directory
- output_root_path.mkdir(exist_ok=True)
-
- # save tensor to this directory
- # TODO(YuliangLiu): get index of the tensor shard
- # e.g. index =
- index = 0
-
- # save tensor to file
- ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
- ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
-
- # dtensor ckpt file always contains only one tensor
- state_dict = {name: tensor}
- save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
-
- # update the weight map
- # * means all shards
- ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
- index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
-
-
-def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
- """
- Get checkpoint file suffix.
-
- Args:
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
-
- Returns:
- str: checkpoint file suffix.
- """
- if use_safetensors:
- return '.safetensors'
- else:
- return '.bin'
-
-
-def generate_checkpoint_shard_file_name(index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None) -> str:
- """
- Generate checkpoint shard file name.
-
- Args:
- index (int): index of the shard.
- total_number (int): total number of shards.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
- prefix (str): prefix of the shard file name. Default: None.
-
- Returns:
- str: checkpoint shard file name.
- """
- suffix = get_checkpoint_file_suffix(use_safetensors)
-
- if prefix is None:
- return f"{index:05d}-of-{total_number:05d}.{suffix}"
- else:
- return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}"
-
-
-def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
- """
- Generate dtensor file name.
-
- Args:
- param_name (str): name of the distributed parameter.
- index (int): index of the shard.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
-
- Returns:
- str: dtensor file name.
- """
- suffix = get_checkpoint_file_suffix(use_safetensors)
- return f'{param_name}.{index}.{suffix}'
-
-
-def save_state_dict_as_shard(
- state_dict: dict,
- checkpoint_path: str,
- index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None,
-) -> None:
- """
- Save state dict as shard.
-
- Args:
- state_dict (dict): state dict.
- checkpoint_path (str): path to the checkpoint file.
- index (int): index of the shard.
- total_number (int): total number of shards.
- prefix (str): prefix of the shard file name.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
- """
- # generate the shard name
- shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
- shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
-
- # save the shard
- save_state_dict(state_dict, str(shard_file_path), use_safetensors)
-
-
-# ========================================
-# Helper functions for loading state dict
-# ========================================
-
-
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
"""
Check whether the checkpoint has an index file.
@@ -608,7 +733,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
- return torch.load(checkpoint_file_path)
+ return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
@@ -654,5 +779,5 @@ def get_shard_filename(weights_name: str, idx: int):
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
- shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+ shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file
diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py
index f8fd1c41a059..385b485b6016 100644
--- a/colossalai/cli/benchmark/models.py
+++ b/colossalai/cli/benchmark/models.py
@@ -1,6 +1,6 @@
import torch
-import colossalai.nn as col_nn
+import colossalai.legacy.nn as col_nn
class MLP(torch.nn.Module):
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 1dfd261d5d01..623160003767 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -94,17 +94,23 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
return np.unravel_index(rank, shape)
@staticmethod
- def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
+ def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
"""Convert a coordinate to a rank.
+ mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
+ With 'wrap', an index out of range is wrapped around.
+ For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns i % 2.
Args:
coords (Tuple[int, ...]): Coordinate to be converted.
shape (Tuple[int, ...]): Shape of the process group mesh.
+ mode (Optional[str]): The mode for numpy.ravel_multi_index.
Returns:
int: Rank of the coordinate.
"""
- return np.ravel_multi_index(coord, shape)
+
+ assert mode in ["raise", "wrap", "clip"]
+ return np.ravel_multi_index(coord, shape, mode)
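Since 'wrap' delegates straight to NumPy, the docstring's example can be checked with i = 3:

```python
import numpy as np

# On a mesh of shape (1, 2, 1), coordinate (0, 3, 0) wraps to (0, 1, 0),
# i.e. rank 3 % 2 == 1.
assert np.ravel_multi_index((0, 3, 0), (1, 2, 1), mode='wrap') == 1
```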
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index 003f0cdd91b6..7186f052ecec 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -15,8 +15,8 @@
from colossalai.context.config import Config
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from colossalai.logging import get_dist_logger
-from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py
index 4c05028041ce..ba601d0bf61a 100644
--- a/colossalai/context/process_group_initializer/initializer_1d.py
+++ b/colossalai/context/process_group_initializer/initializer_1d.py
@@ -2,8 +2,9 @@
# -*- encoding: utf-8 -*-
import torch.distributed as dist
+
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py
index 7fbe3be5901f..999cd5f0cfc6 100644
--- a/colossalai/context/process_group_initializer/initializer_2d.py
+++ b/colossalai/context/process_group_initializer/initializer_2d.py
@@ -3,7 +3,7 @@
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py
index 6b6fdc5d715c..b92ae2eec07e 100644
--- a/colossalai/context/process_group_initializer/initializer_2p5d.py
+++ b/colossalai/context/process_group_initializer/initializer_2p5d.py
@@ -4,9 +4,10 @@
import math
import torch.distributed as dist
+
from colossalai.context import Config
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py
index 1ed8eec86efc..6bca05ad7d5f 100644
--- a/colossalai/context/process_group_initializer/initializer_3d.py
+++ b/colossalai/context/process_group_initializer/initializer_3d.py
@@ -6,7 +6,7 @@
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/context/process_group_initializer/initializer_data.py
index 9715ebff7f00..b9dec4541dad 100644
--- a/colossalai/context/process_group_initializer/initializer_data.py
+++ b/colossalai/context/process_group_initializer/initializer_data.py
@@ -3,7 +3,7 @@
from torch import distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py
index 99b9cc0d4edc..614ba372fbcc 100644
--- a/colossalai/context/process_group_initializer/initializer_model.py
+++ b/colossalai/context/process_group_initializer/initializer_model.py
@@ -2,9 +2,11 @@
# -*- encoding: utf-8 -*-
import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py
index 0ddb52f63e22..e093333ad18a 100644
--- a/colossalai/context/process_group_initializer/initializer_pipeline.py
+++ b/colossalai/context/process_group_initializer/initializer_pipeline.py
@@ -3,7 +3,7 @@
from torch import distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py
index 251a2940778a..a6e26b6bcaa9 100644
--- a/colossalai/context/process_group_initializer/initializer_sequence.py
+++ b/colossalai/context/process_group_initializer/initializer_sequence.py
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .initializer_tensor import Initializer_Tensor
diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/context/process_group_initializer/initializer_tensor.py
index d2b5be9cfffb..3be89e52a812 100644
--- a/colossalai/context/process_group_initializer/initializer_tensor.py
+++ b/colossalai/context/process_group_initializer/initializer_tensor.py
@@ -3,9 +3,10 @@
import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 267c4529eb95..f41af1161be1 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -59,7 +59,7 @@ def __init__(self,
# 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, \
"Only one of mesh_shape and logical_mesh_id can be specified." \
- "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
+ "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
if logical_mesh_id is None:
self._mesh_shape = mesh_shape
@@ -74,7 +74,7 @@ def __init__(self,
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
- "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
+ "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
@@ -118,7 +118,7 @@ def __init__(self,
self._global_rank_of_current_process = None
self._is_initialized = False
- # attribute used to inidicate whether this objectd
+ # attribute used to indicate whether this object
# is created using DeviceMesh.from_process_group
# this attribute can be used to do some check in methods
# such get_process_group as no global rank information
@@ -395,7 +395,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
Example:
```python
- sphysical_mesh_id = torch.arange(0, 16)
+ physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# logical mesh will look like
@@ -438,7 +438,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
# the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]):
- # if this dimension is not initailized yet,
+ # if this dimension is not initialized yet,
# initialize it with an empty array
if dim not in processes_in_the_same_process_group:
processes_in_the_same_process_group[dim] = []
@@ -447,7 +447,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()
# replace the local rank in the given dimension with the
- # lcoal rank of the current process iterated
+ # local rank of the current process iterated
process_coordinates[dim] = _local_rank
processes_in_the_same_process_group[dim].append(process_coordinates)
diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
new file mode 100644
index 000000000000..9a965dc982a4
--- /dev/null
+++ b/colossalai/inference/README.md
@@ -0,0 +1,117 @@
+# 🚀 Colossal-Inference
+
+## Table of contents
+
+## Introduction
+
+`Colossal Inference` is a module containing Colossal-AI's purpose-built inference framework, featuring high performance, stability and ease of use. It incorporates the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and FlashAttention, while combining the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
+
+## Design
+
+Colossal Inference is composed of three main components:
+
+1. High performance kernels and ops: inspired by existing libraries and modified accordingly.
+2. Efficient memory management mechanism: includes the key-value cache manager, allowing for zero memory waste during inference.
+ 1. `cache manager`: serves as a memory manager to help manage the key-value cache; it integrates functions such as memory allocation, indexing and release.
+ 2. `batch_infer_info`: holds all essential elements of a batch inference, and is updated every batch.
+3. High-level inference engine combined with `Shardformer`: allows our inference framework to easily invoke and utilize various parallel methods (see the usage sketch after this list).
+ 1. `engine.TPInferEngine`: a high-level interface that integrates with shardformer, especially for multi-card (tensor parallel) inference.
+ 2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference.
+ 3. `policies.llama.LlamaModelInferPolicy`: contains the policies for `llama` models, which are used to call `shardformer` and partition the model forward pass for tensor parallelism.
+
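A minimal usage sketch of the engine described above; the model path and generation settings are placeholders, and a launched distributed environment is assumed for tensor parallelism:

```python
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig

model = LlamaForCausalLM.from_pretrained("path/to/llama")      # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")

shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=128, max_output_len=128)

inputs = tokenizer(["Introduce some landmarks in Beijing"], return_tensors="pt")
outputs = engine.generate(inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```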
+## Pipeline of inference
+
+In this section we discuss how Colossal Inference works and integrates with `Shardformer`. The details can be found in our code.
+
+
+
+## Roadmap of our implementation
+
+- [x] Design cache manager and batch infer state
+- [x] Design `TPInferEngine` to integrate with `Shardformer`
+- [x] Register corresponding high-performance `kernel` and `ops`
+- [x] Design policies and forwards (e.g. `Llama` and `Bloom`)
+ - [x] policy
+ - [x] context forward
+ - [x] token forward
+- [ ] Replace the kernels with `faster-transformer` in token-forward stage
+- [ ] Support all models
+ - [x] Llama
+ - [x] Bloom
+ - [ ] Chatglm2
+- [ ] Benchmarking for all models
+
+## Get started
+
+### Installation
+
+```bash
+pip install -e .
+```
+
+### Requirements
+
+Dependencies:
+
+```bash
+pytorch= 1.13.1 (gpu)
+cuda>= 11.6
+transformers= 4.30.2
+triton==2.0.0.dev20221202
+# to install vllm, please use this branch: https://github.com/tiandiao123/vllm/tree/setup_branch
+vllm
+# to install flash-attention, please use commit hash 67ae6fd74b4bc99c36b2ce524cf139c35663793c
+flash-attention
+```
+
+### Docker
+
+You can use `docker run` to set up the environment in a container:
+
+```
+# env: python==3.8, cuda 11.6, pytorch == 1.13.1 triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support
+docker pull hpcaitech/colossalai-inference:v2
+docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
+
+```
+
+### Dive into fast-inference!
+
+Example files are located in:
+
+```bash
+cd colossalai.examples
+python xx
+```
+
+## Performance
+
+### Environment
+
+We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `colossal-inference` and the original `hugging-face torch fp16`.
+
+For various models, experiments were conducted using multiple batch sizes under the consistent model configuration of `7 billion (7b)` parameters, `1024` input length, and `128` output length. The obtained results are as follows (due to time constraints, the evaluation has currently been performed solely on `A100` single GPU performance; multi-GPU performance will be addressed in the future):
+
+### Single GPU Performance
+
+Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on the average values of the context-forward and decoding-forward processes, which means we combine both processes to calculate token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned.
+
+#### Llama
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
+| colossal-inference | 326.4 | 582.72 | 816.64 |
+
+
+
+#### Bloom
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
+| colossal-inference | 323.28 | 538.52 | 611.64 |
+
+
+
+The results of more models are coming soon!
diff --git a/tests/test_layers/test_1d/checks_1d/__init__.py b/colossalai/inference/__init__.py
similarity index 100%
rename from tests/test_layers/test_1d/checks_1d/__init__.py
rename to colossalai/inference/__init__.py
diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py
new file mode 100644
index 000000000000..e467b4c73e6b
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/__init__.py
@@ -0,0 +1,4 @@
+from .engine import TPInferEngine
+from .kvcache_manager import MemoryManager
+
+__all__ = ['MemoryManager', 'TPInferEngine']
diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py
new file mode 100644
index 000000000000..2bff9317283e
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -0,0 +1,55 @@
+# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from .kvcache_manager import MemoryManager
+
+
+@dataclass
+class BatchInferState:
+ r"""
+ Information to be passed and used for a batch of inputs during
+ a single model forward
+ """
+ batch_size: int
+ max_len_in_batch: int
+
+ cache_manager: MemoryManager = None
+
+ block_loc: torch.Tensor = None
+ start_loc: torch.Tensor = None
+ seq_len: torch.Tensor = None
+ past_key_values_len: int = None
+
+ is_context_stage: bool = False
+ context_mem_index: torch.Tensor = None
+ decode_is_contiguous: bool = None
+ decode_mem_start: int = None
+ decode_mem_end: int = None
+ decode_mem_index: torch.Tensor = None
+ decode_layer_id: int = None
+
+ device: torch.device = torch.device('cuda')
+
+ @property
+ def total_token_num(self):
+ # return self.batch_size * self.max_len_in_batch
+ assert self.seq_len is not None and self.seq_len.size(0) > 0
+ return int(torch.sum(self.seq_len))
+
+ def set_cache_manager(self, manager: MemoryManager):
+ self.cache_manager = manager
+
+ @staticmethod
+ def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
+ alloc_mem_index: torch.Tensor):
+ """ in-place update block loc mapping based on the sequence length of the inputs in current bath"""
+ start_index = 0
+ seq_len_numpy = seq_len.cpu().numpy()
+ for i, cur_seq_len in enumerate(seq_len_numpy):
+ b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
+ cur_seq_len]
+ start_index += cur_seq_len
+ return
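A toy illustration of the right-alignment this performs (each row's last `seq_len` slots receive that sequence's contiguous cache indices):

```python
import torch

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState

b_loc = torch.zeros(2, 4, dtype=torch.long)
seq_len = torch.tensor([2, 3], dtype=torch.int32)
alloc_mem_index = torch.arange(5)

BatchInferState.init_block_loc(b_loc, seq_len, max_len_in_batch=4, alloc_mem_index=alloc_mem_index)
# row 0 -> [0, 0, 0, 1] (indices 0..1), row 1 -> [0, 2, 3, 4] (indices 2..4)
assert torch.equal(b_loc[0], torch.tensor([0, 0, 0, 1]))
assert torch.equal(b_loc[1], torch.tensor([0, 2, 3, 4]))
```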
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
new file mode 100644
index 000000000000..a5a55702ade0
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -0,0 +1,294 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers import BloomForCausalLM, LlamaForCausalLM
+from transformers.generation import GenerationConfig
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.tokenization_utils_base import BatchEncoding
+
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
+
+from .batch_infer_state import BatchInferState
+from .kvcache_manager import MemoryManager
+
+DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+
+_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
+
+
+class TPInferEngine:
+ """Engine class for tensor parallel inference.
+
+ Args:
+ model (Module): original model, e.g. huggingface CausalLM
+ shard_config (ShardConfig): The config for sharding original model
+ max_batch_size (int): maximum batch size
+ max_input_len (int): maximum input length of sequence
+ max_output_len (int): maximum output length of output tokens
+ dtype (torch.dtype): datatype used to init KV cache space
+ device (str): device on which the KV cache of the engine is initialized
+
+ Examples:
+ >>> # define model and shard config for your inference
+ >>> model = ...
+ >>> generate_kwargs = ...
+ >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+ >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+ >>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
+ """
+
+ def __init__(self,
+ model: nn.Module,
+ shard_config: ShardConfig,
+ max_batch_size: int,
+ max_input_len: int,
+ max_output_len: int,
+ dtype: torch.dtype = torch.float16,
+ device: str = 'cuda') -> None:
+ self.max_batch_size = max_batch_size
+ self.max_input_len = max_input_len
+ self.max_output_len = max_output_len
+ self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
+
+ # Constraints related to the specs of devices and model
+ # This may change into an optional arg in the future
+ assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
+ assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint"
+
+ self.dtype = dtype
+
+ self.head_dim = model.config.hidden_size // model.config.num_attention_heads
+ self.head_num = model.config.num_attention_heads
+ self.layer_num = model.config.num_hidden_layers
+
+ self.tp_size = -1 # to be set with the given shard config in self._prepare_with_shard_config
+ self.cache_manager = None
+
+ self.shard_config = shard_config
+ self.model = None
+ # optimize the original model by sharding with ShardFormer
+ self._optimize_model(model=model.to(device))
+
+ def _init_manager(self) -> None:
+ assert self.tp_size >= 1, "TP size is not initialized; please provide a valid ShardConfig"
+ assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
+ self.head_num //= self.tp_size # update sharded number of heads
+ self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
+ self.layer_num)
+
+ def _optimize_model(self, model: nn.Module) -> None:
+ """
+ Optimize the original model by sharding with ShardFormer.
+ In subsequent generation, the sharded model is used instead of the original model.
+ """
+ # NOTE we will change to use an inference config later with additional attrs we want
+ assert self.shard_config.inference_only is True
+ shardformer = ShardFormer(shard_config=self.shard_config)
+ self._prepare_with_shard_config(shard_config=self.shard_config)
+ self._shard_model_by(shardformer, model)
+
+ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
+ """ Prepare the engine with a given ShardConfig.
+
+ Args:
+ shard_config (ShardConfig): shard config given to specify settings of the engine.
+ If not provided, a default ShardConfig with tp size 1 will be created.
+ """
+ self.tp_size = 1
+ if shard_config is None:
+ shard_config = ShardConfig(
+ tensor_parallel_process_group=None,
+ pipeline_stage_manager=None,
+ enable_tensor_parallelism=False,
+ enable_fused_normalization=False,
+ enable_all_optimization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=False,
+ inference_only=True,
+ )
+ else:
+ shard_config.inference_only = True
+ shard_config.pipeline_stage_manager = None
+ if shard_config.enable_tensor_parallelism:
+ self.tp_size = shard_config.tensor_parallel_size
+ self._init_manager()
+
+ return shard_config
+
+ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
+ """ Shard original model by the given ShardFormer and store the sharded model. """
+ assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
+ "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
+ model_name = model.__class__.__name__
+ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
+ policy = get_autopolicy(model, inference_only=True)
+ self.model, _ = shardformer.optimize(model, policy)
+ self.model = self.model.cuda()
+
+ @property
+ def supported_models(self) -> List[str]:
+ return _supported_models
+
+ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
+ """Generate token sequence.
+
+ Args:
+ input_tokens: could be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+ Returns:
+ torch.Tensor: The returned sequence is given inputs + generated_tokens.
+ """
+ if isinstance(input_tokens, torch.Tensor):
+ input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
+ for t in input_tokens:
+ if torch.is_tensor(input_tokens[t]):
+ input_tokens[t] = input_tokens[t].cuda()
+ if 'max_new_tokens' not in generate_kwargs:
+ generate_kwargs.update(max_new_tokens=self.max_output_len)
+
+ return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
+
+ def prepare_batch_state(self, inputs) -> BatchInferState:
+ """
+        Create and prepare BatchInferState used for inference during model forward,
+ by processing each sequence of the given inputs.
+
+ Args:
+ inputs: should be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+            NOTE: For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve
+                the actual length (e.g. number of tokens) of each input without an attention mask.
+                Hence, for a torch.Tensor with shape [bs, l] where bs > 1, we will assume
+                all the inputs in the batch have the maximum length l.
+ Returns:
+ BatchInferState: the states for the current batch during inference
+ """
+ if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)):
+ raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state")
+
+ input_ids_list = None
+ attention_mask = None
+
+ if isinstance(inputs, (BatchEncoding, dict)):
+ input_ids_list = inputs['input_ids']
+ attention_mask = inputs['attention_mask']
+ else:
+ input_ids_list = inputs
+ if isinstance(input_ids_list[0], int): # for a single input
+ input_ids_list = [input_ids_list]
+ attention_mask = [attention_mask] if attention_mask is not None else attention_mask
+
+ batch_size = len(input_ids_list)
+
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+ start_index = 0
+
+ max_len_in_batch = -1
+ if isinstance(inputs, (BatchEncoding, dict)):
+ for i, attn_mask in enumerate(attention_mask):
+ curr_seq_len = len(attn_mask)
+ # if isinstance(attn_mask, torch.Tensor):
+ # curr_seq_len = int(torch.sum(attn_mask))
+ # else:
+ # curr_seq_len = int(sum(attn_mask))
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ else:
+ length = max(len(input_id) for input_id in input_ids_list)
+ for i, input_ids in enumerate(input_ids_list):
+ curr_seq_len = length
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
+ batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
+ batch_infer_state.seq_len = seq_lengths.to('cuda')
+ batch_infer_state.start_loc = seq_start_indexes.to('cuda')
+ batch_infer_state.block_loc = block_loc
+ batch_infer_state.decode_layer_id = 0
+ batch_infer_state.past_key_values_len = 0
+ batch_infer_state.is_context_stage = True
+ batch_infer_state.set_cache_manager(self.cache_manager)
+ return batch_infer_state
+
+ @torch.no_grad()
+ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
+ """
+ Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate
+
+ Args:
+ inputs: should be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+ """
+
+ # for testing, always use sharded model
+ assert self.model is not None, "sharded model does not exist"
+
+ batch_infer_state = self.prepare_batch_state(input_tokens)
+ assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"
+
+ # set BatchInferState for the current batch as attr to model
+        # NOTE: this is not the preferred way to pass BatchInferState during inference;
+        #   we might want to rewrite the generate function (e.g. _generate_by_pass_infer_state)
+        #   and pass BatchInferState via model forward
+ model = self.model
+ if isinstance(model, LlamaForCausalLM):
+ model = self.model.model
+ elif isinstance(model, BloomForCausalLM):
+ model = self.model.transformer
+ setattr(model, 'infer_state', batch_infer_state)
+
+ outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
+
+        # NOTE In future development, we're going to let the scheduler handle the cache,
+        #      instead of freeing space explicitly at the end of generation
+ self.cache_manager.free_all()
+
+ return outputs
+
+    # TODO: we might want to implement a function that generates output tokens by passing
+    #   BatchInferState as an arg into model.forward.
+    #   It requires rewriting model.generate and replacing model.forward.
+ @torch.no_grad()
+ def _generate_by_pass_infer_state(self,
+ input_tokens,
+ max_out_length: int,
+ generation_config: Optional[GenerationConfig] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ **model_kwargs) -> torch.Tensor:
+
+ raise NotImplementedError("generate by passing BatchInferState is not implemented.")
+
+ # might want to use in rewritten generate method: use after model.forward
+ # BatchInferState is created and kept during generation
+ # after each iter of model forward, we should update BatchInferState
+ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
+ batch_size = infer_state.batch_size
+ device = infer_state.start_loc.device
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
+ infer_state.seq_len += 1
+
+    # might want to create a sequence pool:
+    #   add a single request/sequence/input text at a time and record its length
+    #   In other words, store the actual length of input tokens representing a single input text
+    #   E.g. "Introduce landmarks in Beijing"
+    #   => add request
+    #   => record token length and other necessary information to be used
+    #   => engine holds all this necessary information until `generate` (or another name) is called,
+    #   => put the information already recorded in BatchInferState and pass it to model forward
+    #   => clear the records in the engine
+    def add_request(self):
+ raise NotImplementedError()
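
The engine above ties sharding, KV-cache setup, and generation together. Below is a minimal usage sketch for a single-process run with tp size 1; the constructor argument order is inferred from the attribute assignments above, while the import path, checkpoint path, and launch step are assumptions rather than confirmed API (a distributed context, e.g. via `colossalai.launch`, may be required by ShardFormer).

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.tensor_parallel import TPInferEngine    # assumed import path
from colossalai.shardformer import ShardConfig

# hypothetical checkpoint path; any supported Llama/Bloom causal LM works
model = LlamaForCausalLM.from_pretrained("/path/to/llama", torch_dtype=torch.float16)
tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama")

# inference-only shard config with tensor parallelism disabled (tp size 1)
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=128, max_output_len=64)

inputs = tokenizer(["Introduce landmarks in Beijing"], return_tensors="pt")
outputs = engine.generate(inputs, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
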
diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py
new file mode 100644
index 000000000000..274c01841279
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -0,0 +1,101 @@
+# Adapted from lightllm/common/mem_manager.py
+# of the ModelTC/lightllm GitHub repository
+# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+
+import torch
+from transformers.utils import logging
+
+
+class MemoryManager:
+ r"""
+ Manage token block indexes and allocate physical memory for key and value cache
+
+ Args:
+ size: maximum token number used as the size of key and value buffer
+ dtype: data type of cached key and value
+ head_num: number of heads the memory manager is responsible for
+        head_dim: embedding size per head
+ layer_num: the number of layers in the model
+ device: device used to store the key and value cache
+ """
+
+ def __init__(self,
+ size: int,
+ dtype: torch.dtype,
+ head_num: int,
+ head_dim: int,
+ layer_num: int,
+ device: torch.device = torch.device('cuda')):
+ self.logger = logging.get_logger(__name__)
+ self.available_size = size
+ self.past_key_values_length = 0
+ self._init_mem_states(size, device)
+ self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
+
+ def _init_mem_states(self, size, device):
+ """ Initialize tensors used to manage memory states """
+ self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
+ self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
+ self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
+
+ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
+ """ Initialize key buffer and value buffer on specified device """
+ self.key_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+ self.value_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+
+ @torch.no_grad()
+ def alloc(self, required_size):
+ """ allocate space of required_size by providing indexes representing available physical spaces """
+ if required_size > self.available_size:
+ self.logger.warning(f"No enough cache: required_size {required_size} "
+ f"left_size {self.available_size}")
+ return None
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
+ select_index = self.indexes[select_index]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ return select_index
+
+ @torch.no_grad()
+ def alloc_contiguous(self, required_size):
+ """ allocate contiguous space of required_size """
+ if required_size > self.available_size:
+ self.logger.warning(f"No enough cache: required_size {required_size} "
+ f"left_size {self.available_size}")
+ return None
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ sum_size = len(self.mem_cum_sum)
+ loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size +
+ 1] + self.mem_state[0:sum_size -
+ required_size + 1]
+ can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
+ if can_used_loc.shape[0] == 0:
+ self.logger.info(f"No enough contiguous cache: required_size {required_size} "
+ f"left_size {self.available_size}")
+ return None
+ start_loc = can_used_loc[0]
+ select_index = self.indexes[start_loc:start_loc + required_size]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ start = start_loc.item()
+ end = start + required_size
+ return select_index, start, end
+
+ @torch.no_grad()
+ def free(self, free_index):
+ """ free memory by updating memory states based on given indexes """
+ self.available_size += free_index.shape[0]
+ self.mem_state[free_index] = 1
+
+ @torch.no_grad()
+ def free_all(self):
+ """ free all memory by updating memory states """
+ self.available_size = len(self.mem_state)
+ self.mem_state[:] = 1
+ self.past_key_values_length = 0
+ self.logger.info("freed all space of memory manager")
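
A small demonstration of the allocator semantics above. The CPU device is used here only so the sketch runs without a GPU; the manager defaults to CUDA.

```python
import torch

# 8 token slots, 1 layer, 2 heads of dim 4 (tiny sizes for illustration)
manager = MemoryManager(size=8, dtype=torch.float16, head_num=2, head_dim=4,
                        layer_num=1, device=torch.device('cpu'))

idx = manager.alloc(3)               # tensor([0, 1, 2]): the first three free slots
blk = manager.alloc_contiguous(2)    # (indexes, start, end) for a contiguous run
print(manager.available_size)        # 8 - 3 - 2 = 3 slots left

manager.free(idx)                    # return the first allocation to the pool
manager.free_all()                   # reset every slot and the past-kv length
```
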
diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py
new file mode 100644
index 000000000000..7a98b033f37e
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -0,0 +1,4 @@
+from .bloom import BloomInferenceForwards
+from .llama import LlamaInferenceForwards
+
+__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards']
diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
new file mode 100644
index 000000000000..ba5eadc92be8
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -0,0 +1,519 @@
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+from transformers.models.bloom.modeling_bloom import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BloomAttention,
+ BloomBlock,
+ BloomForCausalLM,
+ BloomModel,
+ CausalLMOutputWithCrossAttentions,
+)
+from transformers.utils import logging
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
+
+
+def generate_alibi(n_head, dtype=torch.float16):
+ """
+ This method is adapted from `_generate_alibi` function
+ in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py`
+ of the ModelTC/lightllm GitHub repository.
+ This method is originally the `build_alibi_tensor` function
+ in `transformers/models/bloom/modeling_bloom.py`
+ of the huggingface/transformers GitHub repository.
+ """
+
+ def get_slopes_power_of_2(n):
+ start = 2**(-(2**-(math.log2(n) - 3)))
+ return [start * start**i for i in range(n)]
+
+ def get_slopes(n):
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2**math.floor(math.log2(n))
+ slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
+ slopes_double = get_slopes(2 * closest_power_of_2)
+ slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2]
+ return slopes_combined
+
+ slopes = get_slopes(n_head)
+ return torch.tensor(slopes, dtype=dtype)
+
+
+class BloomInferenceForwards:
+ """
+    This class serves as a micro library for bloom inference forwards.
+    We intend to replace the forward methods of BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention,
+    as well as the prepare_inputs_for_generation method of BloomForCausalLM.
+    For future improvement, we might want to skip replacing methods for BloomForCausalLM,
+    and call BloomModel.forward iteratively in TpInferEngine
+ """
+
+ @staticmethod
+ def bloom_model_forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ # still need to keep past_key_values to fit original forward flow
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+        # NOTE determine if BatchInferState is passed in via arg
+        #      if not, get the attr bound to the model
+        #      We might want to remove setattr later
+ if infer_state is None:
+ assert hasattr(self, 'infer_state')
+ infer_state = self.infer_state
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ # if self.cache_manager.past_key_values_length > 0:
+ if infer_state.cache_manager.past_key_values_length > 0:
+ # update the past key values length in cache manager,
+            # NOTE use BatchInferState.past_key_values_length instead of the one in cache manager
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ # infer_state.cache_manager = self.cache_manager
+
+ if use_cache and seq_length != 1:
+ # prefill stage
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
+ infer_state.context_mem_index)
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+ print(f" *** Encountered allocation non-contiguous")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+        # NOTE revise: we might want to store a single 1D alibi (length = num_heads) in the model,
+        #      or store it in BatchInferState to avoid re-calculating it.
+        #      When we have multiple process groups (e.g. dp together with tp), we need to pass the pg here.
+ # alibi = generate_alibi(self.num_heads).contiguous().cuda()
+ tp_size = dist.get_world_size()
+ curr_tp_rank = dist.get_rank()
+ alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) *
+ self.num_heads].cuda()
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ # NOTE: currently our KV cache manager does not handle this condition
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ causal_mask,
+ layer_past,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ infer_state=infer_state,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # update indices of kv cache block
+        # NOT READY FOR PRIME TIME
+        # we might want to remove this part; it would be better to pass the BatchInferState from model forward
+        # and update this information in engine.generate after the model forward is called
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.decode_layer_id = 0
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents, # should always be (None, None, ..., None)
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ @staticmethod
+ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state)
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ batch_size, seq_length, vocab_size = shift_logits.shape
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
+ shift_labels.view(batch_size * seq_length))
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def bloom_for_causal_lm_prepare_inputs_for_generation(
+ self: BloomForCausalLM,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict:
+ # only last token for input_ids if past is not None
+ if past_key_values:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ # NOTE we won't use past key values here
+        # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+ # if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+ # past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update({
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ })
+ return model_inputs
+
+ @staticmethod
+ def bloom_block_forward(
+ self: BloomBlock,
+ hidden_states: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ infer_state=infer_state,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+ @staticmethod
+ def bloom_attention_forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+ batch_size, q_length, H, D_HEAD = query_layer.shape
+        k = key_layer.reshape(-1, H, D_HEAD)    # batch_size * q_length, H, D_HEAD, q_length == 1
+        v = value_layer.reshape(-1, H, D_HEAD)    # batch_size * q_length, H, D_HEAD, q_length == 1
+
+ mem_manager = infer_state.cache_manager
+ layer_id = infer_state.decode_layer_id
+
+ if layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_length # += 1
+
+ if infer_state.is_context_stage:
+ # context process
+ max_input_len = q_length
+ b_start_loc = infer_state.start_loc
+ b_seq_len = infer_state.seq_len[:batch_size]
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id])
+
+ # output = self.output[:batch_size*q_length, :, :]
+ output = torch.empty_like(q)
+
+ bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+ else:
+ # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+            # need shape: batch_size, H, D_HEAD (q_length == 1); input q shape: (batch_size, q_length(1), H, D_HEAD)
+ assert q_length == 1, "for non-context process, we only support q_length == 1"
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_v = infer_state.cache_manager.value_buffer[layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_k.copy_(k)
+ cache_v.copy_(v)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+ # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])
+
+ b_start_loc = infer_state.start_loc
+ b_loc = infer_state.block_loc
+ b_seq_len = infer_state.seq_len
+ output = torch.empty_like(q)
+ token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
+ b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi)
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+
+ # update layer id
+ infer_state.decode_layer_id += 1
+
+ # NOTE: always set present as none for now, instead of returning past key value to the next decoding,
+ # we create the past key value pair from the cache manager
+ present = None
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ # dropout is not required here during inference
+ output_tensor = residual + output_tensor
+
+ outputs = (output_tensor, present)
+ assert output_attentions is False, "we do not support output_attentions at this time"
+
+ return outputs
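
For reference, here is a worked example of the slope construction inside `generate_alibi` at the top of this file: for a power-of-two head count the slopes form a geometric sequence, and a non-power-of-two count is topped up from the doubled head count.

```python
import math

def slopes_power_of_2(n):
    # start = 2^(-(2^(-(log2(n) - 3)))); for n = 8 this is 2^-1
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start ** i for i in range(n)]

print(slopes_power_of_2(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]

# For n_head = 6 (not a power of two), generate_alibi takes the slopes of the
# closest power of two (4) and tops them up with every other slope of n = 8:
# [0.25, 0.0625, 0.015625, 0.00390625] + [0.5, 0.125]
```
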
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
new file mode 100644
index 000000000000..07b73a6f4ca6
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -0,0 +1,361 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import (
+ copy_kv_cache_to_dest,
+ llama_context_attn_fwd,
+ rotary_embedding_fwd,
+ token_attention_fwd,
+)
+
+try:
+ from vllm import layernorm_ops, pos_encoding_ops
+ rms_norm = layernorm_ops.rms_norm
+ rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
+ HAS_VLLM_KERNERL = True
+except:
+ print("fall back to original rotary_embedding_neox of huggingface")
+ print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+ print(
+ "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
+ )
+ HAS_VLLM_KERNERL = False
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., :x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+ return
+
+
+class LlamaInferenceForwards:
+ """
+ This class holds forwards for llama inference.
+ We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM.
+ """
+
+ @staticmethod
+ def llama_model_forward(
+ self: LlamaModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+
+        batch_size = input_ids.shape[0]
+
+ infer_state = self.infer_state
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+            # NOT READY FOR PRIME TIME
+            # dummy but works; revise later
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ # past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        # NOTE: differentiate from the prefill stage
+        #       block_loc requires a different value-assigning method for each of the two stages
+        if use_cache and seq_length != 1:
+            # NOTE: assume prefill stage
+ # allocate memory block
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
+ infer_state.context_mem_index)
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+ print(f" *** Encountered allocation non-contiguous")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if infer_state.is_context_stage:
+
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1)
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device)
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
+ past_key_values_length)
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ infer_state.decode_layer_id = 0
+
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+ # NOTE: modify here for passing args to decoder layer
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ infer_state.decode_layer_id += 1
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ hidden_states = self.norm(hidden_states)
+ next_cache = next_decoder_cache if use_cache else None
+
+ # update indices
+ # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ @staticmethod
+ def llama_decoder_layer_forward(
+ self: LlamaDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ @staticmethod
+ def llama_flash_attn_kvcache_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        assert use_cache is True, "use_cache should be set to True when using this llama attention"
+
+ bsz, q_len, _ = hidden_states.size()
+
+        # NOTE: we might think about a better way to handle transposed k and v
+ # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
+ # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+
+        # NOTE might want to revise
+        #   we need some way to record the length of the past key values cache
+        #   since we won't return past_key_value_cache right now
+ if infer_state.decode_layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_len # seq_len
+
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+ # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
+
+ rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+ rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+ key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+ value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+
+ if infer_state.is_context_stage:
+ # first token generation
+
+ # copy key and value calculated in current step to memory manager
+ _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
+ infer_state.cache_manager)
+
+ attn_output = torch.empty_like(query_states)
+
+ llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc,
+ infer_state.seq_len, infer_state.cache_manager.past_key_values_length)
+ else:
+
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_k.copy_(key_states)
+ cache_v.copy_(value_states)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
+ infer_state.decode_mem_index, infer_state.cache_manager)
+
+ # second token and follows
+ # kv = torch.stack((key_states, value_states), dim=2)
+ # (batch_size, seqlen, nheads, headdim)
+ attn_output = torch.empty_like(query_states)
+
+ token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output,
+ infer_state.block_loc, infer_state.start_loc, infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length)
+
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ # return past_key_value as None
+ return attn_output, None, None
+
+
+def get_llama_vllm_rmsnorm_forward():
+
+ if HAS_VLLM_KERNERL:
+
+ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ x = hidden_states
+ out = torch.empty_like(x)
+ rms_norm(
+ out,
+ x,
+ self.weight.data,
+ self.variance_epsilon,
+ )
+
+ return out
+
+ return _vllm_rmsnorm_forward
+ else:
+ return None
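
A quick sanity check of the rotary helpers defined at the top of this file (assuming `apply_rotary_pos_emb` is in scope): applying the embedding is a per-position 2D rotation of paired channels, so vector norms are preserved. The cos/sin cache built here mirrors the usual LLaMA construction and is only illustrative.

```python
import torch

dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()[None, None, :, :]    # [1, 1, seq_len, dim]
sin = emb.sin()[None, None, :, :]

q = torch.randn(1, 2, seq_len, dim)  # [bs, num_heads, seq_len, head_dim]
k = torch.randn(1, 2, seq_len, dim)
position_ids = torch.arange(seq_len)[None, :]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```
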
diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py
new file mode 100644
index 000000000000..48f8db62c32a
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/__init__.py
@@ -0,0 +1,4 @@
+from .bloom import BloomModelInferPolicy
+from .llama import LlamaModelInferPolicy
+
+__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy']
diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py
new file mode 100644
index 000000000000..cae43aa20421
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/bloom.py
@@ -0,0 +1,66 @@
+from functools import partial
+
+import torch
+from torch.nn import LayerNorm
+
+from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
+
+from ..modeling.bloom import BloomInferenceForwards
+
+try:
+ from colossalai.kernel.triton import layer_norm
+ HAS_TRITON_NORM = True
+except:
+ print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
+ HAS_TRITON_NORM = False
+
+
+def get_triton_layernorm_forward():
+ if HAS_TRITON_NORM:
+
+ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
+ return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)
+
+ return _triton_layernorm_forward
+ else:
+ return None
+
+
+class BloomModelInferPolicy(BloomForCausalLMPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
+ policy = super().module_policy()
+ # NOTE set inference mode to shard config
+ self.shard_config._infer()
+
+ method_replacement = {
+ 'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
+ 'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
+ }
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=BloomForCausalLM)
+
+ method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
+
+ method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
+
+ method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=BloomAttention)
+
+ if HAS_TRITON_NORM:
+ infer_method = get_triton_layernorm_forward()
+ method_replacement = {'forward': partial(infer_method)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LayerNorm)
+
+ return policy
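
The policies above work by swapping `forward` on existing HuggingFace modules. Conceptually, such a method replacement boils down to rebinding a function onto a module instance, e.g. via `types.MethodType`; the snippet below sketches that mechanism only and is not ShardFormer's actual implementation.

```python
import types

import torch
import torch.nn as nn


def custom_layernorm_forward(self: nn.LayerNorm, hidden_states: torch.Tensor):
    # stand-in for an optimized (e.g. triton) layernorm kernel
    return nn.functional.layer_norm(hidden_states, self.normalized_shape,
                                    self.weight, self.bias, self.eps)


ln = nn.LayerNorm(8)
ln.forward = types.MethodType(custom_layernorm_forward, ln)
out = ln(torch.randn(2, 8))    # __call__ now dispatches to the replaced forward
```
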
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
new file mode 100644
index 000000000000..4844415d612c
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -0,0 +1,68 @@
+from functools import partial
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+
+# import colossalai
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
+
+try:
+ from colossalai.kernel.triton import rmsnorm_forward
+ HAS_TRITON_RMSNORM = True
+except:
+ print("you should install triton from https://github.com/openai/triton")
+ HAS_TRITON_RMSNORM = False
+
+
+def get_triton_rmsnorm_forward():
+ if HAS_TRITON_RMSNORM:
+
+ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
+ return _triton_rmsnorm_forward
+ else:
+ return None
+
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ self.shard_config._infer()
+
+ infer_forward = LlamaInferenceForwards.llama_model_forward
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+ infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LlamaDecoderLayer)
+
+ infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LlamaAttention)
+
+ infer_forward = None
+ if HAS_TRITON_RMSNORM:
+ infer_forward = get_triton_rmsnorm_forward()
+ else:
+            # NOTE: adding rms_norm from cuda kernels caused a precision issue; fix @tiandiao123
+ infer_forward = get_llama_vllm_rmsnorm_forward()
+
+ if infer_forward is not None:
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LlamaRMSNorm)
+
+ return policy
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index dc0df0517508..a1694e059fb4 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -17,13 +17,13 @@
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
+from colossalai.legacy.builder.builder import build_gradient_handler
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
InterleavedPipelineSchedule,
NonPipelineSchedule,
PipelineSchedule,
diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py
index 8c658e375146..1c3199fc1aff 100644
--- a/colossalai/interface/__init__.py
+++ b/colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper
-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py
index a067d7671ce7..7b3d9435d255 100644
--- a/colossalai/interface/model.py
+++ b/colossalai/interface/model.py
@@ -23,3 +23,14 @@ def unwrap(self):
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+ """This mixin class defines the interface for AMP training.
+ """
+
+ def update_master_params(self):
+ """
+ Update the master parameters for AMP training.
+ """
+ pass
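
A hedged sketch of how a plugin's model wrapper might implement this mixin: keep fp32 master copies next to the fp16 working weights and refresh them on demand. The wrapper name and storage scheme are illustrative, not an existing API; `ModelWrapper.__init__` is assumed to take the wrapped module.

```python
import torch
import torch.nn as nn


class NaiveFP16Wrapper(ModelWrapper, AMPModelMixin):

    def __init__(self, module: nn.Module):
        super().__init__(module.half())
        # fp32 master copies kept alongside the fp16 working parameters
        self.master_params = [p.detach().clone().float() for p in self.module.parameters()]

    def update_master_params(self):
        # refresh the fp32 masters from the (possibly reloaded) working weights
        with torch.no_grad():
            for master, working in zip(self.master_params, self.module.parameters()):
                master.copy_(working.float())
```
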
diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index e20c08b051ed..8eb4e0c880a0 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -1,6 +1,6 @@
import torch
-from colossalai.nn.layer.colossalai_layer import Embedding, Linear
+from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from colossalai.utils import get_current_device
from .bias_dropout_add import bias_dropout_add_fused_train
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
new file mode 100644
index 000000000000..5840ad2918be
--- /dev/null
+++ b/colossalai/kernel/triton/__init__.py
@@ -0,0 +1,12 @@
+from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
+from .copy_kv_cache_dest import copy_kv_cache_to_dest
+from .fused_layernorm import layer_norm
+from .rms_norm import rmsnorm_forward
+from .rotary_embedding_kernel import rotary_embedding_fwd
+from .softmax import softmax
+from .token_attention_kernel import token_attention_fwd
+
+__all__ = [
+ "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+ "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+]
diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py
new file mode 100644
index 000000000000..38db2048c6a4
--- /dev/null
+++ b/colossalai/kernel/triton/context_attention.py
@@ -0,0 +1,184 @@
+import torch
+import math
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+
+if HAS_TRITON:
+ '''
+ this function is modified from
+ https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
+ '''
+ @triton.jit
+ def _context_flash_attention_kernel(
+ Q, K, V, sm_scale,
+ B_Start_Loc, B_Seqlen,
+ TMP,
+ alibi_ptr,
+ Out,
+ stride_qbs, stride_qh, stride_qd,
+ stride_kbs, stride_kh, stride_kd,
+ stride_vbs, stride_vh, stride_vd,
+ stride_obs, stride_oh, stride_od,
+ stride_tmp_b, stride_tmp_h, stride_tmp_s,
+        # suggested settings: 64, 128, 256, 512
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ ):
+
+ batch_id = tl.program_id(0)
+ cur_head = tl.program_id(1)
+ start_m = tl.program_id(2)
+
+ # initialize offsets
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+ # get batch info
+ cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
+ cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
+ block_start_loc = BLOCK_M * start_m
+
+ load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
+ q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+
+ k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
+ v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
+ t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
+
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+ if alibi_ptr is not None:
+ alibi_m = tl.load(alibi_ptr + cur_head)
+
+ block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+
+ for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
+ mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
+
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k)
+ qk *= sm_scale
+
+ if alibi_ptr is not None:
+ alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
+ qk -= alibi_loc * alibi_m
+
+ qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+ m_ij = tl.max(qk, 1)
+ p = tl.exp(qk - m_ij[:, None])
+ l_ij = tl.sum(p, 1)
+ # -- update m_i and l_i
+ m_i_new = tl.maximum(m_i, m_ij)
+ alpha = tl.exp(m_i - m_i_new)
+ beta = tl.exp(m_ij - m_i_new)
+ l_i_new = alpha * l_i + beta * l_ij
+ # -- update output accumulator --
+ # scale p
+ p_scale = beta / l_i_new
+ p = p * p_scale[:, None]
+ # scale acc
+ acc_scale = l_i / l_i_new * alpha
+ tl.store(t_ptrs, acc_scale)
+ acc_scale = tl.load(t_ptrs)
+ acc = acc * acc_scale[:, None]
+ # update acc
+ v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
+ mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
+
+ p = p.to(v.dtype)
+ acc += tl.dot(p, v)
+ # update m_i and l_i
+ l_i = l_i_new
+ m_i = m_i_new
+
+ off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+ out_ptrs = Out + off_o
+ tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+ return
+
+
+ @torch.no_grad()
+ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
+ BLOCK = 128
+ # shape constraints
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert Lq == Lk, "context attention only supports equal head dimensions for q, k and v"
+ assert Lk == Lv, "context attention only supports equal head dimensions for q, k and v"
+ assert Lk in {16, 32, 64, 128}
+
+ sm_scale = 1.0 / math.sqrt(Lk)
+ batch, head = b_seq_len.shape[0], q.shape[1]
+
+ grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
+
+ num_warps = 4 if Lk <= 64 else 8
+
+ tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
+
+ _context_flash_attention_kernel[grid](
+ q, k, v, sm_scale,
+ b_start_loc, b_seq_len,
+ tmp,
+ alibi,
+ o,
+ q.stride(0), q.stride(1), q.stride(2),
+ k.stride(0), k.stride(1), k.stride(2),
+ v.stride(0), v.stride(1), v.stride(2),
+ o.stride(0), o.stride(1), o.stride(2),
+ tmp.stride(0), tmp.stride(1), tmp.stride(2),
+ # block sizes are set manually; an auto-tuning config could speed this up further
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
+
+ @torch.no_grad()
+ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
+ BLOCK = 128
+ # shape constraints
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert Lq == Lk, "context attention only supports equal head dimensions for q, k and v"
+ assert Lk == Lv, "context attention only supports equal head dimensions for q, k and v"
+ assert Lk in {16, 32, 64, 128}
+
+ sm_scale = 1.0 / math.sqrt(Lk)
+ batch, head = b_seq_len.shape[0], q.shape[1]
+
+ grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
+
+ tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
+ num_warps = 4 if Lk <= 64 else 8
+ _context_flash_attention_kernel[grid](
+ q, k, v, sm_scale, b_start_loc, b_seq_len,
+ tmp,
+ None,
+ o,
+ q.stride(0), q.stride(1), q.stride(2),
+ k.stride(0), k.stride(1), k.stride(2),
+ v.stride(0), v.stride(1), v.stride(2),
+ o.stride(0), o.stride(1), o.stride(2),
+ tmp.stride(0), tmp.stride(1), tmp.stride(2),
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
\ No newline at end of file
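
In both wrappers, q/k/v are packed along the token axis with no padding: b_start_loc holds each sequence's starting offset and b_seq_len its length. A hedged usage sketch, assuming a CUDA device and the exports declared in __init__.py above:

import torch
from colossalai.kernel.triton import llama_context_attn_fwd

num_heads, head_dim = 8, 64
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
start_locs = torch.tensor([0, 3], dtype=torch.int32, device="cuda")  # prefix sums of seq_lens
total_tokens = int(seq_lens.sum())

# two prompts packed into one (total_tokens, num_heads, head_dim) tensor
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)

llama_context_attn_fwd(q, k, v, o, start_locs, seq_lens, int(seq_lens.max()))
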
diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py
new file mode 100644
index 000000000000..c1eaa8a10ed1
--- /dev/null
+++ b/colossalai/kernel/triton/copy_kv_cache_dest.py
@@ -0,0 +1,69 @@
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ @triton.jit
+ def _fwd_copy_kv_cache_dest(
+ kv_cache_ptr, dest_index_ptr,
+ out,
+ stride_k_bs,
+ stride_k_h,
+ stride_k_d,
+ stride_o_bs,
+ stride_o_h,
+ stride_o_d,
+ head_num,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_HEAD: tl.constexpr
+ ):
+ cur_index = tl.program_id(0)
+ offs_h = tl.arange(0, BLOCK_HEAD)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+
+ dest_index = tl.load(dest_index_ptr + cur_index)
+
+ cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
+ k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
+
+ o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
+ o_ptrs = out + dest_index * stride_o_bs + o_offsets
+
+ k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
+ tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
+ return
+
+
+ @torch.no_grad()
+ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
+ seq_len = dest_index_ptr.shape[0]
+ head_num = k_ptr.shape[1]
+ head_dim = k_ptr.shape[2]
+ assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
+ assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
+
+ num_warps = 2
+
+ _fwd_copy_kv_cache_dest[(seq_len,)](
+ k_ptr, dest_index_ptr, out,
+ k_ptr.stride(0),
+ k_ptr.stride(1),
+ k_ptr.stride(2),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ head_num,
+ BLOCK_DMODEL=head_dim,
+ BLOCK_HEAD=triton.next_power_of_2(head_num),
+ num_warps=num_warps,
+ num_stages=2,
+ )
+ return
+
+
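
A hedged sketch of the intended use: scatter freshly computed key (or value) rows into a preallocated cache at the slots chosen by the memory manager. The shapes and slot indices are illustrative.

import torch
from colossalai.kernel.triton import copy_kv_cache_to_dest

num_new, num_heads, head_dim, cache_slots = 4, 8, 64, 1024
new_k = torch.randn(num_new, num_heads, head_dim, device="cuda", dtype=torch.float16)
kv_cache = torch.zeros(cache_slots, num_heads, head_dim, device="cuda", dtype=torch.float16)
dest = torch.tensor([10, 11, 12, 13], device="cuda", dtype=torch.int64)

copy_kv_cache_to_dest(new_k, dest, kv_cache)
assert torch.equal(kv_cache[dest], new_k)  # equivalent to kv_cache[dest] = new_k
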
diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/fused_layernorm.py
new file mode 100644
index 000000000000..99800acfbb92
--- /dev/null
+++ b/colossalai/kernel/triton/fused_layernorm.py
@@ -0,0 +1,83 @@
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ # CREDITS: These functions are adapted from the Triton tutorial
+ # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+
+ @triton.jit
+ def _layer_norm_fwd_fused(
+ X, # pointer to the input
+ Y, # pointer to the output
+ W, # pointer to the weights
+ B, # pointer to the biases
+ stride, # how much to increase the pointer when moving by 1 row
+ N, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ # Compute mean
+ mean = 0
+ _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+ _mean += a
+ mean = tl.sum(_mean, axis=0) / N
+ # Compute variance
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+ x = tl.where(cols < N, x - mean, 0.)
+ _var += x * x
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # Normalize and apply linear transformation
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ w = tl.load(W + cols, mask=mask)
+ b = tl.load(B + cols, mask=mask)
+ x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+ x_hat = (x - mean) * rstd
+ y = x_hat * w + b
+ # Write output
+ tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+ @torch.no_grad()
+ def layer_norm(x, weight, bias, eps):
+ # allocate output
+ y = torch.empty_like(x)
+ # reshape input data into 2D tensor
+ x_arg = x.reshape(-1, x.shape[-1])
+ M, N = x_arg.shape
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_SIZE:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ # enqueue kernel
+ _layer_norm_fwd_fused[(M,)](x_arg,
+ y,
+ weight,
+ bias,
+ x_arg.stride(0),
+ N,
+ eps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps)
+ return y
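
Since the kernel always stores its result as fp16, a quick sanity check against torch.nn.functional.layer_norm needs a loose tolerance. A hedged sketch:

import torch
from colossalai.kernel.triton import layer_norm

hidden = 1024
x = torch.randn(16, hidden, device="cuda", dtype=torch.float16)
w = torch.rand(hidden, device="cuda", dtype=torch.float16)
b = torch.rand(hidden, device="cuda", dtype=torch.float16)

y_triton = layer_norm(x, w, b, 1e-5)
y_torch = torch.nn.functional.layer_norm(x.float(), (hidden,), w.float(), b.float(), 1e-5)
assert torch.allclose(y_triton.float(), y_torch, atol=1e-2, rtol=1e-2)
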
diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py
new file mode 100644
index 000000000000..1fb79115f8ce
--- /dev/null
+++ b/colossalai/kernel/triton/rms_norm.py
@@ -0,0 +1,72 @@
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+
+if HAS_TRITON:
+ '''
+ this kernel function is modified from
+ https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
+ '''
+ @triton.jit
+ def _rms_norm_fwd_fused(
+ X, # pointer to the input
+ Y, # pointer to the output
+ W, # pointer to the weights
+ stride, # how much to increase the pointer when moving by 1 row
+ N, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ # Compute variance
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+ _var += x * x
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # Normalize and apply linear transformation
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
+ x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+ x_hat = x * rstd
+ y = x_hat * w
+ # Write output
+ tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+
+ def rmsnorm_forward(x, weight, eps):
+ # allocate output
+ y = torch.empty_like(x)
+ # reshape input data into 2D tensor
+ x_arg = x.view(-1, x.shape[-1])
+ M, N = x_arg.shape
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ # print("BLOCK_SIZE:", BLOCK_SIZE)
+ if N > BLOCK_SIZE:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ # NOTE: the heuristics above are overridden with a fixed launch configuration
+ BLOCK_SIZE = 16384    # i.e. 128 * 2**7
+ num_warps = 8
+ # enqueue kernel
+ _rms_norm_fwd_fused[(M,)](x_arg, y, weight,
+ x_arg.stride(0), N, eps,
+ BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+ return y
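
As a sanity check, RMSNorm differs from LayerNorm only in skipping the mean subtraction, so an eager-mode reference takes a few lines. A hedged sketch (rmsnorm_ref is an illustrative helper, not part of the patch):

import torch
from colossalai.kernel.triton import rmsnorm_forward

def rmsnorm_ref(x, weight, eps):
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.randn(4096, device="cuda", dtype=torch.float16)
assert torch.allclose(rmsnorm_forward(x, w, 1e-6), rmsnorm_ref(x, w, 1e-6), atol=1e-2, rtol=1e-2)
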
diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py
new file mode 100644
index 000000000000..d9d1b2bcf026
--- /dev/null
+++ b/colossalai/kernel/triton/rotary_embedding_kernel.py
@@ -0,0 +1,93 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _rotary_kernel(
+ q,
+ Cos,
+ Sin,
+ q_bs_stride,
+ q_h_stride,
+ q_d_stride,
+ cos_bs_stride,
+ cos_d_stride,
+ total_len,
+ HEAD_NUM: tl.constexpr,
+ BLOCK_HEAD: tl.constexpr,
+ BLOCK_SEQ: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+):
+ current_head_index = tl.program_id(0)
+ current_seq_index = tl.program_id(1)
+
+ current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
+ current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
+
+ dim_range0 = tl.arange(0, HEAD_DIM // 2)
+ dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+ off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
+ None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
+ off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
+ None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride
+
+ off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
+
+ q0 = tl.load(q + off_q0,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+ other=0.0)
+ q1 = tl.load(q + off_q1,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+ other=0.0)
+
+ cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+ sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+
+ out0 = q0 * cos - q1 * sin
+ out1 = q0 * sin + q1 * cos
+
+ tl.store(q + off_q0,
+ out0,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
+ tl.store(q + off_q1,
+ out1,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
+
+ return
+
+
+@torch.no_grad()
+def rotary_embedding_fwd(q, cos, sin):
+ total_len = q.shape[0]
+ head_num = q.shape[1]
+ head_dim = q.shape[2]
+ assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
+ BLOCK_HEAD = 4
+ BLOCK_SEQ = 32
+ grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
+ if head_dim >= 128:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ _rotary_kernel[grid](
+ q,
+ cos,
+ sin,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ cos.stride(0),
+ cos.stride(1),
+ total_len,
+ HEAD_NUM=head_num,
+ BLOCK_HEAD=BLOCK_HEAD,
+ BLOCK_SEQ=BLOCK_SEQ,
+ HEAD_DIM=head_dim,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
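
The kernel rotates the first and second halves of each head dimension against per-token cos/sin tables of width HEAD_DIM // 2 and writes the result back into q. A hedged sketch of building those tables with the usual base-10000 frequencies:

import torch
from colossalai.kernel.triton import rotary_embedding_fwd

total_tokens, num_heads, head_dim = 16, 8, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
positions = torch.arange(total_tokens, device="cuda").float()
angles = torch.outer(positions, inv_freq)    # (total_tokens, head_dim // 2)
cos, sin = angles.cos(), angles.sin()

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float32)
rotary_embedding_fwd(q, cos, sin)    # q is rotated in place
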
diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/self_attention_nofusion.py
similarity index 57%
rename from colossalai/kernel/triton/ops.py
rename to colossalai/kernel/triton/self_attention_nofusion.py
index 5e8d4ba3ec99..6ae54dcb0b38 100644
--- a/colossalai/kernel/triton/ops.py
+++ b/colossalai/kernel/triton/self_attention_nofusion.py
@@ -11,10 +11,11 @@
if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
- from .softmax_kernel import softmax_kernel
+ from .softmax import softmax_kernel
- def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
- r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+ input_mask: torch.Tensor, scale: float):
+ r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
@@ -36,39 +37,49 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
# head_size * num_of_head
d_model = q.shape[-1] * q.shape[-2]
- score_output = torch.empty(
- (batches, H, M, N), device=q.device, dtype=q.dtype)
+ score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda meta: (
batches,
H,
- triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
- triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
- q, k, score_output,
- M, N, K,
- q.stride(0), q.stride(2), q.stride(1), q.stride(3),
- k.stride(0), k.stride(2), k.stride(3), k.stride(1),
- score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+ q,
+ k,
+ score_output,
+ M,
+ N,
+ K,
+ q.stride(0),
+ q.stride(2),
+ q.stride(1),
+ q.stride(3),
+ k.stride(0),
+ k.stride(2),
+ k.stride(3),
+ k.stride(1),
+ score_output.stride(0),
+ score_output.stride(1),
+ score_output.stride(2),
+ score_output.stride(3),
scale=scale,
- # currently manually setting, later on we can use auto-tune config to match best setting
+ # currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)
-
- softmax_output = torch.empty(
- score_output.shape, device=score_output.device, dtype=score_output.dtype)
+
+ softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype)
score_output_shape = score_output.shape
score_output = score_output.view(-1, score_output.shape[-1])
n_rows, n_cols = score_output.shape
if n_rows <= 350000:
-
+
block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
@@ -78,37 +89,39 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
else:
num_warps = 4
- softmax_kernel[(n_rows, )](
+ softmax_kernel[(n_rows,)](
softmax_output,
score_output,
score_output.stride(0),
n_cols,
- mask_ptr = input_mask,
+ mask_ptr=input_mask,
num_warps=num_warps,
BLOCK_SIZE=block_size,
)
else:
- #TODO: change softmax kernel functions to make it suitable for large size dimension
+ # NOTE: fall back to torch here; the softmax kernel should be extended to handle large dimensions
softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
softmax_output = softmax_output.view(*score_output_shape)
batches, H, M, K = softmax_output.shape
N = v.shape[-1]
- output = torch.empty(
- (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
+ output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
grid = lambda meta: (
batches,
H,
- triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
- triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
- softmax_output, v, output,
- M, N, K,
+ softmax_output,
+ v,
+ output,
+ M,
+ N,
+ K,
softmax_output.stride(0),
softmax_output.stride(1),
softmax_output.stride(2),
@@ -129,7 +142,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
)
return output.view(batches, -1, d_model)
-
def self_attention_compute_using_triton(qkv,
input_mask,
layer_past,
@@ -152,58 +164,6 @@ def self_attention_compute_using_triton(qkv,
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
- data_output_triton = self_attention_forward_without_fusion(
- q, k, v, input_mask, scale)
+ data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale)
return data_output_triton
-
-
- def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
- if mask is not None:
- assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
- assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention"
-
- hidden_dim = input.shape[-1]
- output = torch.empty_like(input)
- input = input.view(-1, hidden_dim)
- if mask is not None:
- mask = mask.view(-1, hidden_dim)
- assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"
-
- num_rows, num_cols = input.shape
- block_size = max(triton.next_power_of_2(num_cols), 2)
- num_warps = 16
- if block_size >= 4096:
- num_warps = 16
- elif block_size >= 2048:
- num_warps = 8
- else:
- num_warps = 4
-
- if num_rows <= 350000:
- grid = (num_rows,)
- softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
- else:
- grid = lambda meta: ()
-
- grid = lambda meta: (
- triton.cdiv(num_rows, meta["BLOCK_M"]),
- )
-
- BLOCK_M = 32
- if block_size >= 4096:
- BLOCK_M = 4
- elif block_size >= 2048:
- BLOCK_M = 8
-
- softmax_kernel_2[grid](output_ptr = output,
- input_ptr = input,
- row_stride = input.stride(0),
- n_rows = num_rows,
- n_cols = num_cols,
- mask_ptr = mask,
- # currently manually setting up size
- BLOCK_M = 32,
- BLOCK_SIZE = block_size)
-
- return output
\ No newline at end of file
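
A hedged sketch of calling the unfused path directly. The function is not re-exported from the package __init__, so this assumes importing it from the module itself; the mask is an additive bias over the flattened score rows.

import math
import torch
from colossalai.kernel.triton.self_attention_nofusion import self_attention_forward_without_fusion

batch, seq_len, num_heads, head_size = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
# additive mask over flattened score rows: (batch * num_heads * seq_len, seq_len)
mask = torch.zeros(batch * num_heads * seq_len, seq_len, device="cuda", dtype=torch.float16)

out = self_attention_forward_without_fusion(q, k, v, mask, scale=1.0 / math.sqrt(head_size))
assert out.shape == (batch, seq_len, num_heads * head_size)
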
diff --git a/colossalai/kernel/triton/softmax.py b/colossalai/kernel/triton/softmax.py
new file mode 100644
index 000000000000..c65adaf40dda
--- /dev/null
+++ b/colossalai/kernel/triton/softmax.py
@@ -0,0 +1,96 @@
+import torch
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ '''
+ softmax kernel is modified based on
+ https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
+ '''
+ @triton.jit
+ def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
+ r""" the kernel function for implementing softmax operator
+ Args:
+ output_ptr: the output after finishing softmax operation, (N, hidden_dim)
+ input_ptr: the tensor of input, shape should be (N, hidden_dim)
+ n_cols(tl.constexpr): the number of cols of input
+ BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
+ """
+ row_idx = tl.program_id(0)
+ row_start_ptr = input_ptr + row_idx * row_stride
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ input_ptrs = row_start_ptr + col_offsets
+ row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+ row_minus_max = row - tl.max(row, axis=0)
+
+ if mask_ptr is not None:
+ # load mask into SRAM
+ mask_ptrs = mask_ptr + row_idx * row_stride + col_offsets
+ mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
+
+ # update
+ row_minus_max = row_minus_max + mask
+
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+ output_row_start_ptr = output_ptr + row_idx * row_stride
+ output_ptrs = output_row_start_ptr + col_offsets
+ # Write back output to DRAM
+ tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
+
+
+ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
+ if mask is not None:
+ assert input.shape[-1] == mask.shape[-1], "the last dimensions of input and mask should be the same"
+ assert dim == -1 or dim == len(input.shape) - 1, "currently the softmax layer only supports the last dimension"
+
+ hidden_dim = input.shape[-1]
+ output = torch.empty_like(input)
+ input = input.view(-1, hidden_dim)
+ if mask is not None:
+ mask = mask.view(-1, hidden_dim)
+ assert input.shape[0] == mask.shape[0], "the first dimensions of input and mask should be the same"
+
+ num_rows, num_cols = input.shape
+ block_size = max(triton.next_power_of_2(num_cols), 2)
+ if block_size >= 4096:
+ num_warps = 16
+ elif block_size >= 2048:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ if num_rows <= 350000:
+ grid = (num_rows,)
+ softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps)
+ else:
+ # NOTE: softmax_kernel launches one program per row; for very large row
+ # counts, fall back to torch until a tiled (BLOCK_M) variant is available
+ if mask is not None:
+ input = input + mask
+ output = torch.nn.functional.softmax(input, dim=-1).view(output.shape)
+
+ return output
\ No newline at end of file
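
A hedged sketch of the wrapper above with an additive mask (0 keeps a position, -inf removes it):

import torch
from colossalai.kernel.triton import softmax

scores = torch.randn(4, 512, device="cuda", dtype=torch.float32)
mask = torch.zeros_like(scores)
mask[:, 256:] = float("-inf")    # drop the second half of every row

probs = softmax(scores, mask)
assert torch.allclose(probs.sum(-1), torch.ones(4, device="cuda"), atol=1e-4)
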
diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py
deleted file mode 100644
index c215890badff..000000000000
--- a/colossalai/kernel/triton/softmax_kernel.py
+++ /dev/null
@@ -1,44 +0,0 @@
-try:
- import triton
- import triton.language as tl
- HAS_TRITON = True
-except ImportError:
- HAS_TRITON = False
- print("please install triton from https://github.com/openai/triton")
-
-if HAS_TRITON:
- '''
- softmax kernel is modified based on
- https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
- '''
- @triton.jit
- def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
- r""" the kernel function for implementing softmax operator
- Args:
- output_ptr: the output after finishing softmax operation, (N, hidden_dim)
- input_ptr: the tensor of input, shape should be (N, hidden_dim)
- n_cols(tl.constexpr): the number of cols of input
- BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
- """
- row_idx = tl.program_id(0)
- row_start_ptr = input_ptr + row_idx * row_stride
- col_offsets = tl.arange(0, BLOCK_SIZE)
- input_ptrs = row_start_ptr + col_offsets
- row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
- row_minus_max = row - tl.max(row, axis=0)
-
- if mask_ptr is not None:
- # load mask into SRAM
- mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
- mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
-
- # update
- row_minus_max = row_minus_max + mask
-
- numerator = tl.exp(row_minus_max)
- denominator = tl.sum(numerator, axis=0)
- softmax_output = numerator / denominator
- output_row_start_ptr = output_ptr + row_idx * row_stride
- output_ptrs = output_row_start_ptr + col_offsets
- # Write back output to DRAM
- tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
\ No newline at end of file
diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py
new file mode 100644
index 000000000000..c6b25f4abcec
--- /dev/null
+++ b/colossalai/kernel/triton/token_attention_kernel.py
@@ -0,0 +1,333 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+
+import math
+
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+
+ @triton.jit
+ def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
+ attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride,
+ q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride,
+ attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+ start_n = tl.program_id(2)
+
+ offs_d = tl.arange(0, HEAD_DIM)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ current_batch_start_index = max_kv_cache_len - current_batch_seq_len
+ current_batch_end_index = max_kv_cache_len
+
+ off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
+
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ block_start_index = start_n * BLOCK_N
+ block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
+
+ for start_mark in range(0, block_mask, 1):
+ q = tl.load(Q + off_q + start_mark)
+ offs_n_new = current_batch_start_index + offs_n
+ k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
+ mask=offs_n_new < current_batch_end_index,
+ other=0)
+ off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
+ k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
+ att_value = tl.sum(q[None, :] * k, 1)
+ att_value *= sm_scale
+ off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
+ tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
+ return
+
+ @triton.jit
+ def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen,
+ max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride,
+ q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride,
+ k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr,
+ BLOCK_N: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+ start_n = tl.program_id(2)
+
+ offs_d = tl.arange(0, HEAD_DIM)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ current_batch_start_index = max_kv_cache_len - current_batch_seq_len
+ current_batch_end_index = max_kv_cache_len
+
+ off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
+
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ block_start_index = start_n * BLOCK_N
+ block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
+
+ for start_mark in range(0, block_mask, 1):
+ alibi_m = tl.load(alibi + current_head)
+ q = tl.load(Q + off_q + start_mark)
+ offs_n_new = current_batch_start_index + offs_n
+ k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
+ mask=offs_n_new < current_batch_end_index,
+ other=0)
+ off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
+ k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
+ att_value = tl.sum(q[None, :] * k, 1)
+ att_value *= sm_scale
+ att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
+ off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
+ tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
+ return
+
+ @torch.no_grad()
+ def token_attn_fwd_1(q,
+ k,
+ attn_out,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ alibi=None):
+ BLOCK = 32
+ # shape constraints
+ q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
+ assert q_head_dim == k_head_dim
+ assert k_head_dim in {16, 32, 64, 128}
+ sm_scale = 1.0 / (k_head_dim**0.5)
+
+ batch, head_num = kv_cache_loc.shape[0], q.shape[1]
+
+ grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))
+
+ num_warps = 2    # fixed warp count (the usual head-dim heuristic is not used here)
+
+ if alibi is not None:
+ _token_attn_1_alibi_kernel[grid](
+ q,
+ k,
+ sm_scale,
+ alibi,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ attn_out,
+ kv_cache_loc.stride(0),
+ kv_cache_loc.stride(1),
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ attn_out.stride(0),
+ attn_out.stride(1),
+ HEAD_DIM=k_head_dim,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ else:
+ _token_attn_1_kernel[grid](
+ q,
+ k,
+ sm_scale,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ attn_out,
+ kv_cache_loc.stride(0),
+ kv_cache_loc.stride(1),
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ attn_out.stride(0),
+ attn_out.stride(1),
+ HEAD_DIM=k_head_dim,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
+
+ @triton.jit
+ def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out,
+ logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride,
+ BLOCK_SIZE: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ row = tl.load(softmax_logics + current_head * logics_head_dim_stride +
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
+ mask=col_offsets < current_batch_seq_len,
+ other=-float('inf')).to(tl.float32)
+
+ row_minus_max = row - tl.max(row, axis=0)
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+
+ tl.store(softmax_prob_out + current_head * prob_head_dim_stride +
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
+ softmax_output,
+ mask=col_offsets < current_batch_seq_len)
+ return
+
+ @torch.no_grad()
+ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
+ BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
+ batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]
+
+ num_warps = 4
+ if BLOCK_SIZE >= 2048:
+ num_warps = 8
+ if BLOCK_SIZE >= 4096:
+ num_warps = 16
+
+ _token_attn_softmax_fwd[(batch, head_num)](
+ softmax_logics,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ softmax_prob_out,
+ softmax_logics.stride(0),
+ softmax_logics.stride(1),
+ softmax_prob_out.stride(0),
+ softmax_prob_out.stride(1),
+ num_warps=num_warps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ return
+
+ @triton.jit
+ def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
+ kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride,
+ v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride,
+ attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr,
+ BLOCK_N: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, HEAD_DIM)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_start_index = max_kv_cache_len - current_batch_seq_len
+ current_batch_end_index = current_batch_seq_len
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
+ p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
+ v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride
+
+ acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
+ for start_n in range(0, current_batch_seq_len, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride,
+ mask=(start_n + offs_n) < current_batch_seq_len,
+ other=0.0)
+ v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
+ mask=(start_n + offs_n) < current_batch_seq_len,
+ other=0.0)
+ v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride,
+ mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
+ other=0.0)
+ acc += tl.sum(p_value[:, None] * v_value, 0)
+
+ acc = acc.to(tl.float16)
+ off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride
+ out_ptrs = attn_out + off_o
+ tl.store(out_ptrs, acc)
+ return
+
+ @torch.no_grad()
+ def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len):
+ if triton.__version__ >= "2.1.0":
+ BLOCK = 128
+ else:
+ BLOCK = 64
+ batch, head = kv_cache_loc.shape[0], v.shape[1]
+ grid = (batch, head)
+ num_warps = 4
+ dim = v.shape[-1]
+
+ _token_attn_2_kernel[grid](
+ prob,
+ v,
+ attn_out,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ kv_cache_loc.stride(0),
+ kv_cache_loc.stride(1),
+ prob.stride(0),
+ prob.stride(1),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ attn_out.stride(0),
+ attn_out.stride(1),
+ attn_out.stride(2),
+ HEAD_DIM=dim,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
+
+ @torch.no_grad()
+ def token_attention_fwd(q,
+ k,
+ v,
+ attn_out,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seq_len,
+ max_len_in_batch,
+ alibi=None):
+ head_num = k.shape[1]
+ batch_size = kv_cache_seq_len.shape[0]
+ calcu_shape1 = (batch_size, head_num, k.shape[2])
+ total_token_num = k.shape[0]
+
+ att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
+
+ token_attn_fwd_1(q.view(calcu_shape1),
+ k,
+ att_m_tensor,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seq_len,
+ max_len_in_batch,
+ alibi=alibi)
+
+ prob = torch.empty_like(att_m_tensor)
+
+ token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
+ att_m_tensor = None
+ token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
+ max_len_in_batch)
+
+ prob = None
+
+ return
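
A hedged end-to-end sketch of a single decoding step: each sequence contributes one query token while k/v hold every cached token. Per the index arithmetic in _token_attn_1_kernel, each row of kv_cache_loc is read right-aligned, i.e. positions max_len - seq_len through max_len - 1. Shapes and cache layout are illustrative.

import torch
from colossalai.kernel.triton import token_attention_fwd

batch, num_heads, head_dim = 2, 8, 64
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
start_locs = torch.tensor([0, 3], dtype=torch.int32, device="cuda")
total_tokens, max_len = int(seq_lens.sum()), int(seq_lens.max())

# one query token per sequence; k/v hold every cached token
q = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
out = torch.empty_like(q)

# logical -> physical cache rows, right-aligned within max_len per the kernel
kv_cache_loc = torch.zeros(batch, max_len, dtype=torch.int32, device="cuda")
kv_cache_loc[0, max_len - 3:] = torch.arange(0, 3, dtype=torch.int32, device="cuda")
kv_cache_loc[1, :] = torch.arange(3, 8, dtype=torch.int32, device="cuda")

token_attention_fwd(q, k, v, out, kv_cache_loc, start_locs, seq_lens, max_len)
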
diff --git a/tests/test_layers/test_2d/checks_2d/__init__.py b/colossalai/legacy/__init__.py
similarity index 100%
rename from tests/test_layers/test_2d/checks_2d/__init__.py
rename to colossalai/legacy/__init__.py
diff --git a/colossalai/builder/__init__.py b/colossalai/legacy/builder/__init__.py
similarity index 100%
rename from colossalai/builder/__init__.py
rename to colossalai/legacy/builder/__init__.py
diff --git a/colossalai/builder/builder.py b/colossalai/legacy/builder/builder.py
similarity index 96%
rename from colossalai/builder/builder.py
rename to colossalai/legacy/builder/builder.py
index 4a907601327c..ff14f46dc61f 100644
--- a/colossalai/builder/builder.py
+++ b/colossalai/legacy/builder/builder.py
@@ -3,7 +3,7 @@
import inspect
-from colossalai.registry import *
+from colossalai.legacy.registry import *
def build_from_config(module, config: dict):
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler
Returns:
- An object of :class:`colossalai.engine.BaseGradientHandler`
+ An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
"""
config_ = config.copy()
config_['model'] = model
diff --git a/colossalai/communication/__init__.py b/colossalai/legacy/communication/__init__.py
similarity index 53%
rename from colossalai/communication/__init__.py
rename to colossalai/legacy/communication/__init__.py
index 220481b7af15..88ad0487b785 100644
--- a/colossalai/communication/__init__.py
+++ b/colossalai/legacy/communication/__init__.py
@@ -1,9 +1,17 @@
-from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
-from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
- send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
- recv_forward, recv_backward)
+from .collective import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from .p2p import (
+ recv_backward,
+ recv_forward,
+ send_backward,
+ send_backward_recv_backward,
+ send_backward_recv_forward,
+ send_forward,
+ send_forward_backward_recv_forward_backward,
+ send_forward_recv_backward,
+ send_forward_recv_forward,
+)
from .ring import ring_forward
-from .utils import send_obj_meta, recv_obj_meta
+from .utils import recv_obj_meta, send_obj_meta
__all__ = [
'all_gather',
diff --git a/colossalai/communication/collective.py b/colossalai/legacy/communication/collective.py
similarity index 100%
rename from colossalai/communication/collective.py
rename to colossalai/legacy/communication/collective.py
diff --git a/colossalai/communication/p2p.py b/colossalai/legacy/communication/p2p.py
similarity index 100%
rename from colossalai/communication/p2p.py
rename to colossalai/legacy/communication/p2p.py
diff --git a/colossalai/communication/p2p_v2.py b/colossalai/legacy/communication/p2p_v2.py
similarity index 100%
rename from colossalai/communication/p2p_v2.py
rename to colossalai/legacy/communication/p2p_v2.py
diff --git a/colossalai/communication/ring.py b/colossalai/legacy/communication/ring.py
similarity index 100%
rename from colossalai/communication/ring.py
rename to colossalai/legacy/communication/ring.py
diff --git a/colossalai/communication/utils.py b/colossalai/legacy/communication/utils.py
similarity index 100%
rename from colossalai/communication/utils.py
rename to colossalai/legacy/communication/utils.py
diff --git a/colossalai/engine/__init__.py b/colossalai/legacy/engine/__init__.py
similarity index 100%
rename from colossalai/engine/__init__.py
rename to colossalai/legacy/engine/__init__.py
diff --git a/colossalai/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py
similarity index 97%
rename from colossalai/engine/_base_engine.py
rename to colossalai/legacy/engine/_base_engine.py
index db27ad0e8abe..9af4469f403f 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/legacy/engine/_base_engine.py
@@ -8,11 +8,17 @@
from torch.nn import Module
from torch.nn.modules.loss import _Loss
-from colossalai.engine.gradient_handler import BaseGradientHandler
-from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.schedule import (
+ BaseSchedule,
+ InterleavedPipelineSchedule,
+ NonPipelineSchedule,
+ PipelineSchedule,
+)
from colossalai.logging import get_dist_logger
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
+
class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/legacy/engine/gradient_accumulation/__init__.py
similarity index 94%
rename from colossalai/engine/gradient_accumulation/__init__.py
rename to colossalai/legacy/engine/gradient_accumulation/__init__.py
index 4cb6f4ad7384..670c26d06e55 100644
--- a/colossalai/engine/gradient_accumulation/__init__.py
+++ b/colossalai/legacy/engine/gradient_accumulation/__init__.py
@@ -4,7 +4,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
from ._gradient_accumulation import (
GradAccumDataloader,
@@ -33,7 +33,7 @@ def accumulate_gradient(model: nn.Module,
dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
your dataloader object, would be called like iter(dataloader)
accumulate_size (int): the number of steps to accumulate gradients
- gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
+ gradient_handlers (List[:class:`colossalai.legacy.engine.BaseGradientHandler`]):
list of gradient handler objects. Default is None.
lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
your ``lr_scheduler`` object for gradient accumulation. Defaults to None.
diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
similarity index 98%
rename from colossalai/engine/gradient_accumulation/_gradient_accumulation.py
rename to colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
index cf66be1cd821..c466f7e2d03b 100644
--- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
+++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
@@ -10,7 +10,7 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import conditional_context
@@ -262,7 +262,7 @@ class GradAccumGradientHandler:
before accumulation size is reached.
Args:
- grad_handler (:class:`colossalai.engine.BaseGradientHandler`):
+ grad_handler (:class:`colossalai.legacy.engine.BaseGradientHandler`):
Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`.
accumulate_size (int): The number of steps to accumulate gradients.
diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py
similarity index 100%
rename from colossalai/engine/gradient_handler/__init__.py
rename to colossalai/legacy/engine/gradient_handler/__init__.py
diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py
similarity index 100%
rename from colossalai/engine/gradient_handler/_base_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py
diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
similarity index 90%
rename from colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
index 5cc7169c5a9f..c5da2e55a0ed 100644
--- a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
-from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
similarity index 94%
rename from colossalai/engine/gradient_handler/_moe_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
index b499345d4e18..395d83da0478 100644
--- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
@@ -1,9 +1,9 @@
from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict
-from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
similarity index 97%
rename from colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
index 5b49a9c0360d..7d4d9d73afc8 100644
--- a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
@@ -7,7 +7,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
similarity index 90%
rename from colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
index ea4f0fbb1c71..41098ab39d0c 100644
--- a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
-from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
similarity index 92%
rename from colossalai/engine/gradient_handler/_zero_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
index 19fd1e97f86f..4ca7cd0b0702 100644
--- a/colossalai/engine/gradient_handler/_zero_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
@@ -1,4 +1,4 @@
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/legacy/engine/gradient_handler/utils.py
similarity index 100%
rename from colossalai/engine/gradient_handler/utils.py
rename to colossalai/legacy/engine/gradient_handler/utils.py
diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/legacy/engine/schedule/__init__.py
similarity index 100%
rename from colossalai/engine/schedule/__init__.py
rename to colossalai/legacy/engine/schedule/__init__.py
diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py
similarity index 98%
rename from colossalai/engine/schedule/_base_schedule.py
rename to colossalai/legacy/engine/schedule/_base_schedule.py
index a2d50041127a..7505a3eb20e3 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/legacy/engine/schedule/_base_schedule.py
@@ -95,7 +95,7 @@ def forward_backward_step(self,
"""The process function over a batch of dataset for training or evaluation.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
forward_only (bool): If True, the process won't include backward.
return_loss (bool, optional): If False, the loss won't be returned.
diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
similarity index 97%
rename from colossalai/engine/schedule/_non_pipeline_schedule.py
rename to colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
index b9239d928a7b..b67893c1a0bb 100644
--- a/colossalai/engine/schedule/_non_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
@@ -54,7 +54,7 @@ def forward_backward_step(self,
The returned labels and loss will None if :attr:`return_loss` is False.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
If True, the model is run for the forward pass, else back propagation will be executed.
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
similarity index 98%
rename from colossalai/engine/schedule/_pipeline_schedule.py
rename to colossalai/legacy/engine/schedule/_pipeline_schedule.py
index 9fc301a26559..4571fd679e8c 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -6,7 +6,7 @@
import torch.cuda
-import colossalai.communication as comm
+import colossalai.legacy.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -236,7 +236,7 @@ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=T
Returns output tensor. This is a helper function and can be ignored by users.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether returns output labels.
@@ -274,7 +274,7 @@ def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
This is a helper function and can be ignored by users.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.
@@ -314,7 +314,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
Returns a tuple with losses if the last stage, an empty tuple otherwise.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether run forward step only. Default is false. If true, no backward will be run.
@@ -518,7 +518,7 @@ def _forward_step(self,
Returns output tensor. This is a helper function and can be ignored by users.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
model_chunk_id (int): The id of model chunks.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
@@ -555,7 +555,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
communication between pipeline stages as needed.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether run forward step only. Default is false. If true, no backward will be run.
diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
similarity index 96%
rename from colossalai/engine/schedule/_pipeline_schedule_v2.py
rename to colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
index 89e45c7aacec..385c615372f5 100644
--- a/colossalai/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -5,10 +5,10 @@
import torch.cuda
-import colossalai.communication.p2p_v2 as comm
-from colossalai import engine
+import colossalai.legacy.communication.p2p_v2 as comm
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.engine import Engine
from colossalai.utils.cuda import get_current_device
from ._pipeline_schedule import PipelineSchedule
@@ -60,7 +60,7 @@ def data_process_func(stage_output, dataloader_output):
"""
def forward_backward_step(self,
- engine: engine.Engine,
+ engine: Engine,
data_iter: Iterable,
forward_only=False,
return_loss=True,
@@ -69,7 +69,7 @@ def forward_backward_step(self,
Returns a tuple with losses if the last stage, an empty tuple otherwise.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether run forward step only. Default is false. If true, no backward will be run.
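
For orientation, a minimal sketch of the call contract documented above — hedged, not part of the patch: it assumes `engine`, `schedule`, and `train_dataloader` come from a legacy `colossalai.initialize` pipeline setup, and only the import root of the engine has changed.

```python
# Hedged sketch: drive one pipeline step with the relocated Engine.
# `engine`, `schedule` and `train_dataloader` are assumed to exist
# from a legacy colossalai.initialize setup.
data_iter = iter(train_dataloader)      # "iterator obtained by calling iter(dataloader)"
result = schedule.forward_backward_step(
    engine,                             # now colossalai.legacy.engine.Engine
    data_iter,
    forward_only=False,                 # also run the backward pass
    return_loss=True,                   # per the docstring: losses only on the last stage
)
```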
diff --git a/colossalai/legacy/nn/__init__.py b/colossalai/legacy/nn/__init__.py
new file mode 100644
index 000000000000..500162901905
--- /dev/null
+++ b/colossalai/legacy/nn/__init__.py
@@ -0,0 +1,4 @@
+from ._ops import *
+from .layer import *
+from .loss import *
+from .metric import *
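
The new package file above only re-exports the moved submodules, so downstream code changes nothing but its import root — for example (names assume the rest of the move in this patch):

```python
# Before: from colossalai.nn import Linear1D_Col, CrossEntropyLoss2D
from colossalai.legacy.nn import Linear1D_Col, CrossEntropyLoss2D
```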
diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/legacy/nn/_ops/__init__.py
similarity index 100%
rename from colossalai/nn/_ops/__init__.py
rename to colossalai/legacy/nn/_ops/__init__.py
diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py
similarity index 99%
rename from colossalai/nn/_ops/_utils.py
rename to colossalai/legacy/nn/_ops/_utils.py
index 24877bbb552f..131c2154771b 100644
--- a/colossalai/nn/_ops/_utils.py
+++ b/colossalai/legacy/nn/_ops/_utils.py
@@ -4,7 +4,7 @@
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.nn.layer.utils import divide
+from colossalai.legacy.nn.layer.utils import divide
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
GeneralTensor = Union[ColoTensor, torch.Tensor]
@@ -232,7 +232,7 @@ def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
-### table wise embedding shard
+# table wise embedding shard
def _all_to_all_for_tablewise(x: torch.Tensor,
diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/legacy/nn/_ops/addmm.py
similarity index 100%
rename from colossalai/nn/_ops/addmm.py
rename to colossalai/legacy/nn/_ops/addmm.py
diff --git a/colossalai/nn/_ops/batch_norm.py b/colossalai/legacy/nn/_ops/batch_norm.py
similarity index 100%
rename from colossalai/nn/_ops/batch_norm.py
rename to colossalai/legacy/nn/_ops/batch_norm.py
diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/legacy/nn/_ops/element_wise.py
similarity index 100%
rename from colossalai/nn/_ops/element_wise.py
rename to colossalai/legacy/nn/_ops/element_wise.py
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/legacy/nn/_ops/embedding.py
similarity index 98%
rename from colossalai/nn/_ops/embedding.py
rename to colossalai/legacy/nn/_ops/embedding.py
index a045f305b5dc..b145d1763380 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/legacy/nn/_ops/embedding.py
@@ -1,8 +1,10 @@
-import torch.nn.functional as F
from typing import Optional
+
+import torch.nn.functional as F
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \
- ReplicaSpec
+
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/legacy/nn/_ops/embedding_bag.py
similarity index 97%
rename from colossalai/nn/_ops/embedding_bag.py
rename to colossalai/legacy/nn/_ops/embedding_bag.py
index 0026f579b6dc..9a656d5871a3 100644
--- a/colossalai/nn/_ops/embedding_bag.py
+++ b/colossalai/legacy/nn/_ops/embedding_bag.py
@@ -1,9 +1,11 @@
-import torch.nn.functional as F
from typing import Optional
+
+import torch.nn.functional as F
from torch import Tensor
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \
- ShardSpec, ReplicaSpec
+
from ._utils import GeneralTensor, convert_to_colo_tensor
diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/legacy/nn/_ops/layernorm.py
similarity index 92%
rename from colossalai/nn/_ops/layernorm.py
rename to colossalai/legacy/nn/_ops/layernorm.py
index 2b761b84e3ee..9960c5d48096 100644
--- a/colossalai/nn/_ops/layernorm.py
+++ b/colossalai/legacy/nn/_ops/layernorm.py
@@ -1,7 +1,10 @@
from typing import List, Optional
+
import torch.nn.functional as F
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
+
from ._utils import GeneralTensor, convert_to_colo_tensor
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/legacy/nn/_ops/linear.py
similarity index 100%
rename from colossalai/nn/_ops/linear.py
rename to colossalai/legacy/nn/_ops/linear.py
diff --git a/colossalai/nn/_ops/loss.py b/colossalai/legacy/nn/_ops/loss.py
similarity index 96%
rename from colossalai/nn/_ops/loss.py
rename to colossalai/legacy/nn/_ops/loss.py
index 1e54f662859c..90efbfa36f2a 100644
--- a/colossalai/nn/_ops/loss.py
+++ b/colossalai/legacy/nn/_ops/loss.py
@@ -1,9 +1,12 @@
+from typing import Optional
+
import torch
import torch.nn.functional as F
-from typing import Optional
-from colossalai.tensor.op_wrapper import colo_op_impl
+
+from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from colossalai.tensor import ColoTensor, ColoTensorSpec
-from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
+from colossalai.tensor.op_wrapper import colo_op_impl
+
from ._utils import GeneralTensor, convert_to_colo_tensor
diff --git a/colossalai/nn/_ops/view.py b/colossalai/legacy/nn/_ops/view.py
similarity index 100%
rename from colossalai/nn/_ops/view.py
rename to colossalai/legacy/nn/_ops/view.py
diff --git a/colossalai/legacy/nn/layer/__init__.py b/colossalai/legacy/nn/layer/__init__.py
new file mode 100644
index 000000000000..86961dd933a7
--- /dev/null
+++ b/colossalai/legacy/nn/layer/__init__.py
@@ -0,0 +1,9 @@
+from .colossalai_layer import *
+from .parallel_1d import *
+from .parallel_2d import *
+from .parallel_2p5d import *
+from .parallel_3d import *
+from .parallel_sequence import *
+from .utils import *
+from .vanilla import *
+from .wrapper import *
diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/legacy/nn/layer/base_layer.py
similarity index 100%
rename from colossalai/nn/layer/base_layer.py
rename to colossalai/legacy/nn/layer/base_layer.py
diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py
similarity index 97%
rename from colossalai/nn/layer/colossalai_layer/__init__.py
rename to colossalai/legacy/nn/layer/colossalai_layer/__init__.py
index 2ae1b07a75b2..ed743820ddbc 100644
--- a/colossalai/nn/layer/colossalai_layer/__init__.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py
@@ -1,7 +1,7 @@
-from ._utils import partition_batch
-from .dropout import Dropout
-from .embedding import Embedding, PatchEmbedding
-from .linear import Classifier, Linear
-from .normalization import LayerNorm
-
-__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']
+from ._utils import partition_batch
+from .dropout import Dropout
+from .embedding import Embedding, PatchEmbedding
+from .linear import Classifier, Linear
+from .normalization import LayerNorm
+
+__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']
diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py
similarity index 100%
rename from colossalai/nn/layer/colossalai_layer/_utils.py
rename to colossalai/legacy/nn/layer/colossalai_layer/_utils.py
diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py
similarity index 100%
rename from colossalai/nn/layer/colossalai_layer/dropout.py
rename to colossalai/legacy/nn/layer/colossalai_layer/dropout.py
diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py
similarity index 97%
rename from colossalai/nn/layer/colossalai_layer/embedding.py
rename to colossalai/legacy/nn/layer/colossalai_layer/embedding.py
index e5c9c46e0ff1..28bcb7ffefb0 100644
--- a/colossalai/nn/layer/colossalai_layer/embedding.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py
@@ -1,151 +1,152 @@
-import math
-from typing import Callable
-
-from colossalai.utils import get_current_device
-from torch import dtype, nn
-
-from ... import init as init
-from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
-from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
-from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D
-from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D
-from ..utils import get_tensor_parallel_mode
-from ..vanilla import VanillaPatchEmbedding
-from ._utils import ColossalaiModule
-
-_parallel_embedding = {
- '1d': Embedding1D,
- '2d': Embedding2D,
- '2.5d': Embedding2p5D,
- '3d': Embedding3D,
-}
-
-_vocab_parallel_embedding = {
- '1d': VocabParallelEmbedding1D,
- '2d': VocabParallelEmbedding2D,
- '2.5d': VocabParallelEmbedding2p5D,
- '3d': VocabParallelEmbedding3D
-}
-
-_parallel_patchembedding = {
- None: VanillaPatchEmbedding,
- '1d': PatchEmbedding1D,
- '2d': PatchEmbedding2D,
- '2.5d': PatchEmbedding2p5D,
- '3d': PatchEmbedding3D
-}
-
-
-class Embedding(ColossalaiModule):
- r"""Embedding for colossalai.
-
- Args:
- num_embeddings (int): number of embeddings.
- embedding_dim (int): dimension of embedding.
- padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
- therefore, the embedding vector at padding_idx is not updated during training,
- i.e. it remains as a fixed “pad”, defaults to None.
- dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
- weight_initializer (:class:`typing.Callable`, optional):
- he initializer of weight, defaults to normal initializer.
-
- The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
- ::
-
- max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
- renormalized to have norm max_norm. Note: this will modify weight in-place.
- norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
- scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
- of frequency of the words in the mini-batch. Default False.
- sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
-
- More details about ``args`` and ``kwargs`` could be found in
- `Embedding `_.
-
- More details about ``initializer`` please refer to
- `init `_
- """
-
- def __init__(self,
- num_embeddings: int,
- embedding_dim: int,
- padding_idx: int = None,
- dtype: dtype = None,
- weight_initializer: Callable = init.normal_(),
- vocab_parallel_limit: int = 2048,
- *args,
- **kwargs) -> None:
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel is None:
- embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
- **kwargs).to(dtype).to(get_current_device())
- weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
- elif num_embeddings <= vocab_parallel_limit:
- embed = _parallel_embedding[tensor_parallel](
- num_embeddings,
- embedding_dim,
- padding_idx=padding_idx,
- dtype=dtype,
- weight_initializer=weight_initializer,
- *args,
- **kwargs,
- )
- else:
- embed = _vocab_parallel_embedding[tensor_parallel](
- num_embeddings,
- embedding_dim,
- padding_idx=padding_idx,
- dtype=dtype,
- weight_initializer=weight_initializer,
- *args,
- **kwargs,
- )
- super().__init__(embed)
-
-
-class PatchEmbedding(ColossalaiModule):
- """2D Image to Patch Embedding.
-
- Args:
- img_size (int): image size.
- patch_size (int): patch size.
- in_chans (int): number of channels of input image.
- embed_size (int): size of embedding.
- dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
- flatten (bool, optional): whether to flatten output tensor, defaults to True.
- weight_initializer (:class:`typing.Callable`, optional):
- The initializer of weight, defaults to kaiming uniform initializer.
- bias_initializer (:class:`typing.Callable`, optional):
- The initializer of bias, defaults to xavier uniform initializer.
- position_embed_initializer (:class:`typing.Callable`, optional):
- The initializer of position embedding, defaults to zeros initializer.
-
- More details about ``initializer`` please refer to
- `init `_.
- """
-
- def __init__(
- self,
- img_size: int,
- patch_size: int,
- in_chans: int,
- embed_size: int,
- dtype: dtype = None,
- flatten: bool = True,
- weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
- bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
- position_embed_initializer: Callable = init.zeros_()
- ) -> None:
- tensor_parallel = get_tensor_parallel_mode()
- embed = _parallel_patchembedding[tensor_parallel](
- img_size,
- patch_size,
- in_chans,
- embed_size,
- dtype=dtype,
- flatten=flatten,
- weight_initializer=weight_initializer,
- bias_initializer=bias_initializer,
- position_embed_initializer=position_embed_initializer,
- )
- super().__init__(embed)
+import math
+from typing import Callable
+
+from torch import dtype, nn
+
+from colossalai.nn import init
+from colossalai.utils import get_current_device
+
+from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
+from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
+from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D
+from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaPatchEmbedding
+from ._utils import ColossalaiModule
+
+_parallel_embedding = {
+ '1d': Embedding1D,
+ '2d': Embedding2D,
+ '2.5d': Embedding2p5D,
+ '3d': Embedding3D,
+}
+
+_vocab_parallel_embedding = {
+ '1d': VocabParallelEmbedding1D,
+ '2d': VocabParallelEmbedding2D,
+ '2.5d': VocabParallelEmbedding2p5D,
+ '3d': VocabParallelEmbedding3D
+}
+
+_parallel_patchembedding = {
+ None: VanillaPatchEmbedding,
+ '1d': PatchEmbedding1D,
+ '2d': PatchEmbedding2D,
+ '2.5d': PatchEmbedding2p5D,
+ '3d': PatchEmbedding3D
+}
+
+
+class Embedding(ColossalaiModule):
+ r"""Embedding for colossalai.
+
+ Args:
+ num_embeddings (int): number of embeddings.
+ embedding_dim (int): dimension of embedding.
+ padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
+ therefore, the embedding vector at padding_idx is not updated during training,
+ i.e. it remains as a fixed “pad”, defaults to None.
+ dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+ weight_initializer (:class:`typing.Callable`, optional):
+ The initializer of weight, defaults to normal initializer.
+
+ The ``args`` and ``kwargs`` used in :func:`torch.nn.functional.embedding` should contain:
+ ::
+
+ max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
+ renormalized to have norm max_norm. Note: this will modify weight in-place.
+ norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
+ scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
+ of frequency of the words in the mini-batch. Default False.
+ sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
+
+ More details about ``args`` and ``kwargs`` can be found in
+ `Embedding `_.
+
+ For more details about ``initializer``, please refer to
+ `init `_
+ """
+
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: int = None,
+ dtype: dtype = None,
+ weight_initializer: Callable = init.normal_(),
+ vocab_parallel_limit: int = 2048,
+ *args,
+ **kwargs) -> None:
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel is None:
+ embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
+ **kwargs).to(dtype).to(get_current_device())
+ weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
+ elif num_embeddings <= vocab_parallel_limit:
+ embed = _parallel_embedding[tensor_parallel](
+ num_embeddings,
+ embedding_dim,
+ padding_idx=padding_idx,
+ dtype=dtype,
+ weight_initializer=weight_initializer,
+ *args,
+ **kwargs,
+ )
+ else:
+ embed = _vocab_parallel_embedding[tensor_parallel](
+ num_embeddings,
+ embedding_dim,
+ padding_idx=padding_idx,
+ dtype=dtype,
+ weight_initializer=weight_initializer,
+ *args,
+ **kwargs,
+ )
+ super().__init__(embed)
+
+
+class PatchEmbedding(ColossalaiModule):
+ """2D Image to Patch Embedding.
+
+ Args:
+ img_size (int): image size.
+ patch_size (int): patch size.
+ in_chans (int): number of channels of input image.
+ embed_size (int): size of embedding.
+ dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+ flatten (bool, optional): whether to flatten output tensor, defaults to True.
+ weight_initializer (:class:`typing.Callable`, optional):
+ The initializer of weight, defaults to kaiming uniform initializer.
+ bias_initializer (:class:`typing.Callable`, optional):
+ The initializer of bias, defaults to xavier uniform initializer.
+ position_embed_initializer (:class:`typing.Callable`, optional):
+ The initializer of position embedding, defaults to zeros initializer.
+
+ For more details about ``initializer``, please refer to
+ `init `_.
+ """
+
+ def __init__(
+ self,
+ img_size: int,
+ patch_size: int,
+ in_chans: int,
+ embed_size: int,
+ dtype: dtype = None,
+ flatten: bool = True,
+ weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+ bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+ position_embed_initializer: Callable = init.zeros_()
+ ) -> None:
+ tensor_parallel = get_tensor_parallel_mode()
+ embed = _parallel_patchembedding[tensor_parallel](
+ img_size,
+ patch_size,
+ in_chans,
+ embed_size,
+ dtype=dtype,
+ flatten=flatten,
+ weight_initializer=weight_initializer,
+ bias_initializer=bias_initializer,
+ position_embed_initializer=position_embed_initializer,
+ )
+ super().__init__(embed)
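
The `Embedding` wrapper above dispatches on `get_tensor_parallel_mode()` and on vocabulary size. A hedged usage sketch — the sizes are made up, and a launched Colossal-AI tensor-parallel context is assumed:

```python
from colossalai.legacy.nn import Embedding

# vocab <= vocab_parallel_limit (default 2048): picks _parallel_embedding[mode],
# which shards along the embedding dimension.
small = Embedding(num_embeddings=1024, embedding_dim=512)

# larger vocab: falls through to _vocab_parallel_embedding[mode], which shards
# the table along the vocabulary dimension instead.
large = Embedding(num_embeddings=50304, embedding_dim=512)
```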
diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/legacy/nn/layer/colossalai_layer/linear.py
similarity index 99%
rename from colossalai/nn/layer/colossalai_layer/linear.py
rename to colossalai/legacy/nn/layer/colossalai_layer/linear.py
index 3e0c6e285c1c..c05ceb66ce25 100644
--- a/colossalai/nn/layer/colossalai_layer/linear.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/linear.py
@@ -4,9 +4,9 @@
from torch import dtype, nn
+from colossalai.nn import init
from colossalai.utils import get_current_device
-from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py
similarity index 97%
rename from colossalai/nn/layer/colossalai_layer/normalization.py
rename to colossalai/legacy/nn/layer/colossalai_layer/normalization.py
index 86861d30214a..f8e317e723f1 100644
--- a/colossalai/nn/layer/colossalai_layer/normalization.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py
@@ -1,41 +1,42 @@
-from colossalai.utils import get_current_device
-from torch import nn
-
-from ..parallel_1d import LayerNorm1D
-from ..parallel_2d import LayerNorm2D
-from ..parallel_2p5d import LayerNorm2p5D
-from ..parallel_3d import LayerNorm3D
-from ..utils import get_tensor_parallel_mode
-from ..vanilla import VanillaLayerNorm
-from ._utils import ColossalaiModule
-
-_parallel_layernorm = {
- None: VanillaLayerNorm,
- "1d": LayerNorm1D,
- "2d": LayerNorm2D,
- "2.5d": LayerNorm2p5D,
- "3d": LayerNorm3D,
-}
-
-
-class LayerNorm(ColossalaiModule):
- r"""Layer Normalization for colossalai.
-
- Args:
- normalized_shape (int): input shape from an expected input of size.
- :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
- \times \ldots \times \text{normalized_shape}[-1]]`
- If a single integer is used, it is treated as a singleton list, and this module will
- normalize over the last dimension which is expected to be of that specific size.
- eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
- bias (bool, optional): Whether to add a bias, defaults to ``True``.
- dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
- """
-
- def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel is None:
- norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
- else:
- norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
- super().__init__(norm)
+from torch import nn
+
+from colossalai.utils import get_current_device
+
+from ..parallel_1d import LayerNorm1D
+from ..parallel_2d import LayerNorm2D
+from ..parallel_2p5d import LayerNorm2p5D
+from ..parallel_3d import LayerNorm3D
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaLayerNorm
+from ._utils import ColossalaiModule
+
+_parallel_layernorm = {
+ None: VanillaLayerNorm,
+ "1d": LayerNorm1D,
+ "2d": LayerNorm2D,
+ "2.5d": LayerNorm2p5D,
+ "3d": LayerNorm3D,
+}
+
+
+class LayerNorm(ColossalaiModule):
+ r"""Layer Normalization for colossalai.
+
+ Args:
+ normalized_shape (int): input shape from an expected input of size.
+ :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+ \times \ldots \times \text{normalized_shape}[-1]]`
+ If a single integer is used, it is treated as a singleton list, and this module will
+ normalize over the last dimension which is expected to be of that specific size.
+ eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+ bias (bool, optional): Whether to add a bias, defaults to ``True``.
+ dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+ """
+
+ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel is None:
+ norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
+ else:
+ norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
+ super().__init__(norm)
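
Correspondingly for `LayerNorm`: with no tensor-parallel mode set it is a thin wrapper over `torch.nn.LayerNorm` placed on the current device. A brief sketch (shape illustrative; a launched context is assumed for the parallel branches):

```python
import torch
from colossalai.legacy.nn import LayerNorm

# Falls back to torch.nn.LayerNorm when get_tensor_parallel_mode() is None;
# otherwise dispatches to LayerNorm1D/2D/2p5D/3D.
norm = LayerNorm(normalized_shape=512, eps=1e-5, dtype=torch.float16)
```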
diff --git a/colossalai/legacy/nn/layer/parallel_1d/__init__.py b/colossalai/legacy/nn/layer/parallel_1d/__init__.py
new file mode 100644
index 000000000000..9cffd4d339f5
--- /dev/null
+++ b/colossalai/legacy/nn/layer/parallel_1d/__init__.py
@@ -0,0 +1,17 @@
+from .layers import (
+ Classifier1D,
+ Dropout1D,
+ Embedding1D,
+ LayerNorm1D,
+ Linear1D,
+ Linear1D_Col,
+ Linear1D_Row,
+ PatchEmbedding1D,
+ VocabParallelClassifier1D,
+ VocabParallelEmbedding1D,
+)
+
+__all__ = [
+ 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
+ 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
+]
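
The `Linear1D_Col`/`Linear1D_Row` pair re-exported above implements the usual Megatron-style split. A plain-PyTorch, single-process illustration of why that pairing needs only one all-reduce (this emulation is mine, not code from the patch):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)                      # (batch, in_features)
w1 = torch.randn(32, 8)                    # first linear, 32 hidden units
w2 = torch.randn(8, 32)                    # second linear back to 8

full = (x @ w1.t()).relu() @ w2.t()        # unsharded reference

# Column parallel (Linear1D_Col): split w1 by output features.
# Row parallel (Linear1D_Row): split w2 by input features.
partials = [
    (x @ a.t()).relu() @ b.t()
    for a, b in zip(w1.chunk(2, dim=0), w2.chunk(2, dim=1))
]
# Summing the partial outputs is exactly what the final all-reduce does.
assert torch.allclose(sum(partials), full, atol=1e-5)
```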
diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
similarity index 100%
rename from colossalai/nn/layer/parallel_1d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_1d/_operation.py
diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/legacy/nn/layer/parallel_1d/_utils.py
similarity index 99%
rename from colossalai/nn/layer/parallel_1d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_1d/_utils.py
index 1212d595635d..fddf4e73db51 100644
--- a/colossalai/nn/layer/parallel_1d/_utils.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_utils.py
@@ -3,6 +3,7 @@
import torch
import torch.distributed as dist
+
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
@@ -124,7 +125,7 @@ def backward(ctx, grad_output):
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
-
+
Args:
input_: input matrix.
parallel_mode: parallel mode.
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_1d/layers.py
rename to colossalai/legacy/nn/layer/parallel_1d/layers.py
index 406173a18c60..c0a169c1596f 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py
@@ -10,13 +10,13 @@
from torch import Tensor
from torch.nn.parameter import Parameter
-from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
+from colossalai.legacy.communication import broadcast
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/legacy/nn/layer/parallel_2d/__init__.py
similarity index 59%
rename from colossalai/nn/layer/parallel_2d/__init__.py
rename to colossalai/legacy/nn/layer/parallel_2d/__init__.py
index 5562d1a70036..9c65f3608710 100644
--- a/colossalai/nn/layer/parallel_2d/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/__init__.py
@@ -1,6 +1,13 @@
from ._operation import reduce_by_batch_2d, split_batch_2d
-from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D,
- VocabParallelEmbedding2D)
+from .layers import (
+ Classifier2D,
+ Embedding2D,
+ LayerNorm2D,
+ Linear2D,
+ PatchEmbedding2D,
+ VocabParallelClassifier2D,
+ VocabParallelEmbedding2D,
+)
__all__ = [
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
similarity index 98%
rename from colossalai/nn/layer/parallel_2d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_2d/_operation.py
index 306577dbd933..fa9b49bcf53f 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
@@ -2,13 +2,14 @@
import torch
import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce, reduce_scatter
+from colossalai.utils import get_current_device
def matmul_2d(
@@ -226,9 +227,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)
src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opb = [None] * 2
@@ -351,9 +352,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)
src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
opb = [None] * 2
opr = [None] * 2
@@ -484,9 +485,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)
src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opr = [None] * 2
diff --git a/colossalai/nn/layer/parallel_2d/_utils.py b/colossalai/legacy/nn/layer/parallel_2d/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_2d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_2d/_utils.py
diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_2d/layers.py
rename to colossalai/legacy/nn/layer/parallel_2d/layers.py
index f3a4d2bbbc32..b458d15c78e7 100644
--- a/colossalai/nn/layer/parallel_2d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py
@@ -5,21 +5,30 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from colossalai.communication import broadcast
+from torch import Tensor
+from torch.nn import Parameter
+
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import broadcast
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
- reduce_scatter_tensor_2d, split_batch_2d)
+from ._operation import (
+ Matmul_AB_2D,
+ Matmul_ABT_2D,
+ add_bias_2d,
+ all_gather_tensor_2d,
+ classifier_2d,
+ layernorm_2d,
+ reduce_scatter_tensor_2d,
+ split_batch_2d,
+)
from ._utils import assert_summa_initialization, get_summa_dim_from_env
diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py
similarity index 59%
rename from colossalai/nn/layer/parallel_2p5d/__init__.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/__init__.py
index bec3b1c4b0b8..23e47e6ed06b 100644
--- a/colossalai/nn/layer/parallel_2p5d/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py
@@ -1,6 +1,13 @@
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
-from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D,
- VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D)
+from .layers import (
+ Classifier2p5D,
+ Embedding2p5D,
+ LayerNorm2p5D,
+ Linear2p5D,
+ PatchEmbedding2p5D,
+ VocabParallelClassifier2p5D,
+ VocabParallelEmbedding2p5D,
+)
__all__ = [
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
similarity index 99%
rename from colossalai/nn/layer/parallel_2p5d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
index 5a0f537cd6d9..55defa4a328d 100644
--- a/colossalai/nn/layer/parallel_2p5d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
@@ -2,12 +2,13 @@
import torch
import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
from colossalai.utils import get_current_device
-from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
def get_parallel_group(parallel_mode: ParallelMode):
diff --git a/colossalai/nn/layer/parallel_2p5d/_utils.py b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_2p5d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/_utils.py
diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_2p5d/layers.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/layers.py
index f849cbbe7b0d..04acc2bb0f4c 100644
--- a/colossalai/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
@@ -5,22 +5,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from colossalai.communication import broadcast
+from torch import Tensor
+from torch.nn import Parameter
+
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import broadcast
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
- partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+ broadcast_state_dict,
+ gather_tensor_parallel_state_dict,
+ partition_tensor_parallel_state_dict,
+)
from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
- layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
+from ._operation import (
+ Matmul_AB_2p5D,
+ Matmul_ABT_2p5D,
+ add_bias_2p5d,
+ all_gather_tensor_2p5d,
+ classifier_2p5d,
+ layernorm_2p5d,
+ reduce_scatter_tensor_2p5d,
+ split_batch_2p5d,
+)
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/legacy/nn/layer/parallel_3d/__init__.py
similarity index 62%
rename from colossalai/nn/layer/parallel_3d/__init__.py
rename to colossalai/legacy/nn/layer/parallel_3d/__init__.py
index 9ae255b449ee..17fe8403c585 100644
--- a/colossalai/nn/layer/parallel_3d/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/__init__.py
@@ -1,6 +1,13 @@
from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d
-from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D,
- VocabParallelEmbedding3D)
+from .layers import (
+ Classifier3D,
+ Embedding3D,
+ LayerNorm3D,
+ Linear3D,
+ PatchEmbedding3D,
+ VocabParallelClassifier3D,
+ VocabParallelEmbedding3D,
+)
__all__ = [
'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D',
diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/legacy/nn/layer/parallel_3d/_operation.py
similarity index 99%
rename from colossalai/nn/layer/parallel_3d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_3d/_operation.py
index 5dc9a242851f..ca0b0e62783a 100755
--- a/colossalai/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/_operation.py
@@ -7,10 +7,10 @@
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
-from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
from ._utils import get_parallel_mode_from_env, push_async_grad
diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/legacy/nn/layer/parallel_3d/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_3d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_3d/_utils.py
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_3d/layers.py
rename to colossalai/legacy/nn/layer/parallel_3d/layers.py
index 99b0c3f8b7ec..b815a842ca52 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py
@@ -8,14 +8,14 @@
from torch import Tensor
from torch.nn import Parameter
-from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import all_reduce, broadcast
+from colossalai.legacy.nn.layer.base_layer import ParallelLayer
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
diff --git a/colossalai/nn/layer/parallel_sequence/__init__.py b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py
similarity index 74%
rename from colossalai/nn/layer/parallel_sequence/__init__.py
rename to colossalai/legacy/nn/layer/parallel_sequence/__init__.py
index 4fa9eed6f34b..d92d66d40a8e 100644
--- a/colossalai/nn/layer/parallel_sequence/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py
@@ -1,4 +1,4 @@
-from ._operation import RingQK, RingAV
+from ._operation import RingAV, RingQK
from .layers import TransformerSelfAttentionRing
__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK']
diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
similarity index 97%
rename from colossalai/nn/layer/parallel_sequence/_operation.py
rename to colossalai/legacy/nn/layer/parallel_sequence/_operation.py
index fc80494224c6..fcf2962017a3 100644
--- a/colossalai/nn/layer/parallel_sequence/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
@@ -3,13 +3,13 @@
import torch
from torch import distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
-from colossalai.communication import ring_forward
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
+from colossalai.legacy.communication import ring_forward
+from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
class RingQK(torch.autograd.Function):
diff --git a/colossalai/nn/layer/parallel_sequence/_utils.py b/colossalai/legacy/nn/layer/parallel_sequence/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_sequence/_utils.py
rename to colossalai/legacy/nn/layer/parallel_sequence/_utils.py
diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py
similarity index 98%
rename from colossalai/nn/layer/parallel_sequence/layers.py
rename to colossalai/legacy/nn/layer/parallel_sequence/layers.py
index 0887f8389dbe..e44e61c2fb7d 100644
--- a/colossalai/nn/layer/parallel_sequence/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py
@@ -2,20 +2,20 @@
# -*- encoding: utf-8 -*-
import math
-import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
+import colossalai
+from colossalai.context import seed
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
-from colossalai.registry import LAYERS
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.kernel import FusedScaleMaskSoftmax
-from colossalai.context import seed
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
+from colossalai.legacy.registry import LAYERS
@LAYERS.register_module
diff --git a/colossalai/legacy/nn/layer/utils/__init__.py b/colossalai/legacy/nn/layer/utils/__init__.py
new file mode 100644
index 000000000000..56e969bfd0bd
--- /dev/null
+++ b/colossalai/legacy/nn/layer/utils/__init__.py
@@ -0,0 +1,15 @@
+from .common import (
+ ACT2FN,
+ CheckpointModule,
+ _ntuple,
+ divide,
+ get_tensor_parallel_mode,
+ set_tensor_parallel_attribute_by_partition,
+ set_tensor_parallel_attribute_by_size,
+ to_2tuple,
+)
+
+__all__ = [
+ 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
+ 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
+]
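
Two of the helpers re-exported here are used throughout the moved layers; their semantics, as inferred from those call sites (hedged, not verified against the implementation):

```python
from colossalai.legacy.nn.layer.utils import divide, to_2tuple

heads_per_rank = divide(32, 4)   # 8; expected to raise if 32 % 4 != 0
patch_hw = to_2tuple(16)         # (16, 16), e.g. a square patch size
```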
diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/legacy/nn/layer/utils/common.py
similarity index 99%
rename from colossalai/nn/layer/utils/common.py
rename to colossalai/legacy/nn/layer/utils/common.py
index f2297304fdc9..d8f3ad2a7eca 100644
--- a/colossalai/nn/layer/utils/common.py
+++ b/colossalai/legacy/nn/layer/utils/common.py
@@ -6,10 +6,11 @@
import numpy as np
import torch
+from torch import Tensor, nn
+
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import checkpoint
-from torch import Tensor, nn
class CheckpointModule(nn.Module):
diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/legacy/nn/layer/vanilla/__init__.py
similarity index 100%
rename from colossalai/nn/layer/vanilla/__init__.py
rename to colossalai/legacy/nn/layer/vanilla/__init__.py
diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py
similarity index 99%
rename from colossalai/nn/layer/vanilla/layers.py
rename to colossalai/legacy/nn/layer/vanilla/layers.py
index 225aed3916a6..0e11fc4d0dab 100644
--- a/colossalai/nn/layer/vanilla/layers.py
+++ b/colossalai/legacy/nn/layer/vanilla/layers.py
@@ -8,8 +8,8 @@
from torch.nn.parameter import Parameter
from colossalai.context import seed
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.cuda import get_current_device
from ..utils import to_2tuple
diff --git a/colossalai/nn/layer/wrapper/__init__.py b/colossalai/legacy/nn/layer/wrapper/__init__.py
similarity index 100%
rename from colossalai/nn/layer/wrapper/__init__.py
rename to colossalai/legacy/nn/layer/wrapper/__init__.py
diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
similarity index 99%
rename from colossalai/nn/layer/wrapper/pipeline_wrapper.py
rename to colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
index ef1d794cc68f..68fea8622c5c 100644
--- a/colossalai/nn/layer/wrapper/pipeline_wrapper.py
+++ b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
@@ -1,6 +1,8 @@
-import torch.nn as nn
-import torch.distributed as dist
from typing import List, Tuple, Union
+
+import torch.distributed as dist
+import torch.nn as nn
+
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
diff --git a/colossalai/legacy/nn/loss/__init__.py b/colossalai/legacy/nn/loss/__init__.py
new file mode 100644
index 000000000000..1bd8872d9c3a
--- /dev/null
+++ b/colossalai/legacy/nn/loss/__init__.py
@@ -0,0 +1,41 @@
+from torch import nn
+from torch.nn.modules.loss import *
+from torch.nn.modules.loss import _Loss
+
+from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode
+
+from .loss_1d import VocabParallelCrossEntropyLoss1D
+from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
+from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
+from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
+
+_parallel_cross_entropy = {
+ '2d': CrossEntropyLoss2D,
+ '2.5d': CrossEntropyLoss2p5D,
+ '3d': CrossEntropyLoss3D,
+}
+
+_vocab_parallel_cross_entropy = {
+ '1d': VocabParallelCrossEntropyLoss1D,
+ '2d': VocabParallelCrossEntropyLoss2D,
+ '2.5d': VocabParallelCrossEntropyLoss2p5D,
+ '3d': VocabParallelCrossEntropyLoss3D,
+}
+
+
+class CrossEntropyLoss(_Loss):
+
+ def __init__(self, reduction: bool = True, *args, **kwargs):
+ super().__init__()
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel is not None and env.vocab_parallel:
+ self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
+ elif tensor_parallel is None or tensor_parallel == '1d':
+ reduction = 'mean' if reduction else 'none'
+ self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
+ else:
+ self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
+
+ def forward(self, *args):
+ return self.loss(*args)
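
The boolean `reduction` flag above is mapped onto PyTorch's string argument in the non-parallel branch; an equivalent plain-PyTorch sketch of just that branch (an illustration, not the patch's code):

```python
import torch
import torch.nn as nn

def make_criterion(reduction: bool = True) -> nn.CrossEntropyLoss:
    # reduction=True averages over samples, False keeps per-sample losses
    return nn.CrossEntropyLoss(reduction='mean' if reduction else 'none')

logits = torch.randn(8, 1000)               # (batch, vocab)
targets = torch.randint(0, 1000, (8,))      # class indices
print(make_criterion(True)(logits, targets))         # scalar
print(make_criterion(False)(logits, targets).shape)  # torch.Size([8])
```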
diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/legacy/nn/loss/loss_1d.py
similarity index 96%
rename from colossalai/nn/loss/loss_1d.py
rename to colossalai/legacy/nn/loss/loss_1d.py
index dd548c1d3dd4..8c9483fccaec 100644
--- a/colossalai/nn/loss/loss_1d.py
+++ b/colossalai/legacy/nn/loss/loss_1d.py
@@ -1,105 +1,106 @@
-import torch
-import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import LOSSES
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.modules.loss import _Loss
-
-
-class _VocabParallelCrossEntropy1D(torch.autograd.Function):
-
- @staticmethod
- @custom_fwd(cast_inputs=torch.float32)
- def forward(ctx, vocab_parallel_logits, targets, process_group):
- if process_group is None:
- process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
-
- # Maximum value along vocab dimension across all GPUs.
- logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
- torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
- # Subtract the maximum value.
- vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
-
- # Get the partition's vocab indices
- partition_vocab_size = vocab_parallel_logits.size()[-1]
- rank = dist.get_rank(process_group)
- vocab_start_index = partition_vocab_size * rank
- vocab_end_index = vocab_start_index + partition_vocab_size
-
- # Create a mask of valid vocab ids (1 means it needs to be masked).
- target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
- masked_target = targets.clone() - vocab_start_index
- masked_target[target_mask] = 0
-
- # Get predicted-logits = logits[target].
- # For Simplicity, we convert logits to a 2-D tensor with size
- # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
- logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
- masked_target_1d = masked_target.view(-1)
- arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
- predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
- predicted_logits_1d = predicted_logits_1d.clone().contiguous()
- predicted_logits = predicted_logits_1d.view_as(targets)
- predicted_logits[target_mask] = 0.0
- # All reduce is needed to get the chunks from other GPUs.
- torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
-
- # Sum of exponential of logits along vocab dimension across all GPUs.
- exp_logits = torch.exp(vocab_parallel_logits)
- sum_exp_logits = exp_logits.sum(dim=-1)
- torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
-
- # Loss = log(sum(exp(logits))) - predicted-logit.
- loss = torch.log(sum_exp_logits) - predicted_logits
- # Store softmax, target-mask and masked-target for backward pass.
- exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
- ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
- return loss
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad_output):
-
- # Retrieve tensors from the forward path.
- softmax, target_mask, masked_target_1d = ctx.saved_tensors
-
- # All the inputs have softmax as their gradient.
- grad_input = softmax
- # For simplicity, work with the 2D gradient.
- partition_vocab_size = softmax.size()[-1]
- grad_2d = grad_input.view(-1, partition_vocab_size)
-
- # Add the gradient from matching classes.
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
- grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
-
- # Finally elementwise multiplication with the output gradients.
- grad_input.mul_(grad_output.unsqueeze(dim=-1))
-
- return grad_input, None, None
-
-
-@LOSSES.register_module
-class VocabParallelCrossEntropyLoss1D(_Loss):
- """Vocab parallel cross entropy loss for 1D parallelism.
-
- Args:
- reduction (bool, optional): whether to average the loss, defaults to True.
- """
-
- def __init__(self, reduction=True):
- super().__init__()
- self.reduction_mean = reduction
-
- def forward(self, logits, targets, process_group=None):
- """Calculate loss between logits and targets.
-
- Args:
- logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
- """
- loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
- if self.reduction_mean:
- loss = loss.mean()
- return loss
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.modules.loss import _Loss
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+
+
+class _VocabParallelCrossEntropy1D(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, vocab_parallel_logits, targets, process_group):
+ if process_group is None:
+ process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
+
+ # Maximum value along vocab dimension across all GPUs.
+ logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+ torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
+ # Subtract the maximum value.
+ vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+ # Get the partition's vocab indices
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ rank = dist.get_rank(process_group)
+ vocab_start_index = partition_vocab_size * rank
+ vocab_end_index = vocab_start_index + partition_vocab_size
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
+ masked_target = targets.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ # Get predicted-logits = logits[target].
+ # For simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(targets)
+ predicted_logits[target_mask] = 0.0
+ # All reduce is needed to get the chunks from other GPUs.
+ torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+ exp_logits = torch.exp(vocab_parallel_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+ torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+ # Loss = log(sum(exp(logits))) - predicted-logit.
+ loss = torch.log(sum_exp_logits) - predicted_logits
+ # Store softmax, target-mask and masked-target for backward pass.
+ exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+ ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+ return loss
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+
+ # Retrieve tensors from the forward path.
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+ # All the inputs have softmax as their gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+ grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
+
+ # Finally elementwise multiplication with the output gradients.
+ grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+ return grad_input, None, None
+
+
+@LOSSES.register_module
+class VocabParallelCrossEntropyLoss1D(_Loss):
+ """Vocab parallel cross entropy loss for 1D parallelism.
+
+ Args:
+ reduction (bool, optional): whether to average the loss, defaults to True.
+ """
+
+ def __init__(self, reduction=True):
+ super().__init__()
+ self.reduction_mean = reduction
+
+ def forward(self, logits, targets, process_group=None):
+ """Calculate loss between logits and targets.
+
+ Args:
+ logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+ """
+ loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
+ if self.reduction_mean:
+ loss = loss.mean()
+ return loss
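
The forward pass above computes `log(sum(exp(logits))) - predicted_logit` from shard-local pieces glued together by all-reduces. A single-process sketch (world size 2, sizes made up, no distributed init) that mirrors that arithmetic and checks it against `F.cross_entropy`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, batch, vocab = 2, 4, 10
logits = torch.randn(batch, vocab)
targets = torch.randint(0, vocab, (batch,))
shards = list(logits.chunk(world_size, dim=-1))    # each rank's vocab slice

# all-reduce(MAX) over shard maxima, then subtract for numerical stability
logits_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]
shards = [s - logits_max.unsqueeze(-1) for s in shards]

part = vocab // world_size
predicted = torch.zeros(batch)
sum_exp = torch.zeros(batch)
for rank, s in enumerate(shards):
    lo = rank * part
    owned = (targets >= lo) & (targets < lo + part)    # inverse of target_mask
    local = torch.where(owned, targets - lo, torch.zeros_like(targets))
    pred = s[torch.arange(batch), local]
    predicted += torch.where(owned, pred, torch.zeros_like(pred))   # all-reduce(SUM)
    sum_exp += s.exp().sum(dim=-1)                                  # all-reduce(SUM)

loss = torch.log(sum_exp) - predicted
assert torch.allclose(loss.mean(), F.cross_entropy(logits, targets), atol=1e-5)
```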
diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py
similarity index 96%
rename from colossalai/nn/loss/loss_2d.py
rename to colossalai/legacy/nn/loss/loss_2d.py
index 7da8b2d697fa..6191602b71ee 100644
--- a/colossalai/nn/loss/loss_2d.py
+++ b/colossalai/legacy/nn/loss/loss_2d.py
@@ -1,15 +1,16 @@
import torch
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
-from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
+from colossalai.legacy.registry import LOSSES
+from colossalai.utils import get_current_device
+
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py
similarity index 95%
rename from colossalai/nn/loss/loss_2p5d.py
rename to colossalai/legacy/nn/loss/loss_2p5d.py
index 63dc4f33ad32..2746b201152c 100644
--- a/colossalai/nn/loss/loss_2p5d.py
+++ b/colossalai/legacy/nn/loss/loss_2p5d.py
@@ -1,15 +1,16 @@
import torch
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
-from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
+from colossalai.legacy.registry import LOSSES
+from colossalai.utils import get_current_device
+
@LOSSES.register_module
class CrossEntropyLoss2p5D(_Loss):
diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py
similarity index 95%
rename from colossalai/nn/loss/loss_3d.py
rename to colossalai/legacy/nn/loss/loss_3d.py
index f27d57ad6c99..2aeb1bd9825d 100644
--- a/colossalai/nn/loss/loss_3d.py
+++ b/colossalai/legacy/nn/loss/loss_3d.py
@@ -1,15 +1,16 @@
import torch
import torch.distributed as dist
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
-from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
+from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
+from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.legacy.registry import LOSSES
+from colossalai.utils import get_current_device
+
@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
diff --git a/colossalai/nn/metric/__init__.py b/colossalai/legacy/nn/metric/__init__.py
similarity index 87%
rename from colossalai/nn/metric/__init__.py
rename to colossalai/legacy/nn/metric/__init__.py
index 00833b6119c1..76c6dac89c5b 100644
--- a/colossalai/nn/metric/__init__.py
+++ b/colossalai/legacy/nn/metric/__init__.py
@@ -1,26 +1,28 @@
-from torch import nn
-
-from ._utils import calc_acc
-from .accuracy_2d import Accuracy2D
-from .accuracy_2p5d import Accuracy2p5D
-from .accuracy_3d import Accuracy3D
-from colossalai.nn.layer.utils import get_tensor_parallel_mode
-
-_parallel_accuracy = {
- '2d': Accuracy2D,
- '2.5d': Accuracy2p5D,
- '3d': Accuracy3D,
-}
-
-
-class Accuracy(nn.Module):
- def __init__(self):
- super().__init__()
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel not in _parallel_accuracy:
- self.acc = calc_acc
- else:
- self.acc = _parallel_accuracy[tensor_parallel]()
-
- def forward(self, *args):
- return self.acc(*args)
+from torch import nn
+
+from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode
+
+from ._utils import calc_acc
+from .accuracy_2d import Accuracy2D
+from .accuracy_2p5d import Accuracy2p5D
+from .accuracy_3d import Accuracy3D
+
+_parallel_accuracy = {
+ '2d': Accuracy2D,
+ '2.5d': Accuracy2p5D,
+ '3d': Accuracy3D,
+}
+
+
+class Accuracy(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel not in _parallel_accuracy:
+ self.acc = calc_acc
+ else:
+ self.acc = _parallel_accuracy[tensor_parallel]()
+
+ def forward(self, *args):
+ return self.acc(*args)
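
For reference, a hypothetical usage sketch of the relocated metric (it assumes a tensor-parallel context has already been initialized so that get_tensor_parallel_mode() resolves; logits and targets are placeholders):

    from colossalai.legacy.nn.metric import Accuracy

    metric = Accuracy()                # dispatches to Accuracy2D/2p5D/3D, or plain calc_acc
    num_correct = metric(logits, targets)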
diff --git a/colossalai/nn/metric/_utils.py b/colossalai/legacy/nn/metric/_utils.py
similarity index 95%
rename from colossalai/nn/metric/_utils.py
rename to colossalai/legacy/nn/metric/_utils.py
index eac591b64c65..8706ffc101b0 100644
--- a/colossalai/nn/metric/_utils.py
+++ b/colossalai/legacy/nn/metric/_utils.py
@@ -1,7 +1,7 @@
-import torch
-
-
-def calc_acc(logits, targets):
- preds = torch.argmax(logits, dim=-1)
- correct = torch.sum(targets == preds)
- return correct
+import torch
+
+
+def calc_acc(logits, targets):
+ preds = torch.argmax(logits, dim=-1)
+ correct = torch.sum(targets == preds)
+ return correct
diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/legacy/nn/metric/accuracy_2d.py
similarity index 89%
rename from colossalai/nn/metric/accuracy_2d.py
rename to colossalai/legacy/nn/metric/accuracy_2d.py
index a86832973cfd..838c48834a96 100644
--- a/colossalai/nn/metric/accuracy_2d.py
+++ b/colossalai/legacy/nn/metric/accuracy_2d.py
@@ -1,7 +1,8 @@
import torch
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from torch import nn
+from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+
from ._utils import calc_acc
diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/legacy/nn/metric/accuracy_2p5d.py
similarity index 88%
rename from colossalai/nn/metric/accuracy_2p5d.py
rename to colossalai/legacy/nn/metric/accuracy_2p5d.py
index 3044da065de1..183380cd9846 100644
--- a/colossalai/nn/metric/accuracy_2p5d.py
+++ b/colossalai/legacy/nn/metric/accuracy_2p5d.py
@@ -1,7 +1,8 @@
import torch
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from torch import nn
+from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+
from ._utils import calc_acc
diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/legacy/nn/metric/accuracy_3d.py
similarity index 85%
rename from colossalai/nn/metric/accuracy_3d.py
rename to colossalai/legacy/nn/metric/accuracy_3d.py
index 5506fc1d2ffc..1aaac73ecabd 100644
--- a/colossalai/nn/metric/accuracy_3d.py
+++ b/colossalai/legacy/nn/metric/accuracy_3d.py
@@ -1,33 +1,35 @@
-import torch
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
-from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from torch import nn
-
-from ._utils import calc_acc
-
-
-class Accuracy3D(nn.Module):
- """Accuracy for 3D parallelism
- """
- def __init__(self):
- super().__init__()
- self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
- self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
-
- def forward(self, logits, targets):
- """Calculate the accuracy of predicted labels.
-
- Args:
- logits (:class:`torch.tensor`): Predicted labels.
- targets (:class:`torch.tensor`): True labels from data.
-
- Returns:
- float: the accuracy of prediction.
- """
- with torch.no_grad():
- targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
- targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
- correct = calc_acc(logits, targets)
- correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode)
- return correct
+import torch
+from torch import nn
+
+from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
+from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+
+from ._utils import calc_acc
+
+
+class Accuracy3D(nn.Module):
+ """Accuracy for 3D parallelism
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+
+ def forward(self, logits, targets):
+ """Calculate the accuracy of predicted labels.
+
+ Args:
+            logits (:class:`torch.tensor`): Predicted unnormalized scores (logits).
+ targets (:class:`torch.tensor`): True labels from data.
+
+ Returns:
+            :class:`torch.Tensor`: the number of correct predictions.
+ """
+ with torch.no_grad():
+ targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
+ targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
+ correct = calc_acc(logits, targets)
+ correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode)
+ return correct
diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/legacy/nn/parallel/__init__.py
similarity index 100%
rename from colossalai/nn/parallel/__init__.py
rename to colossalai/legacy/nn/parallel/__init__.py
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py
similarity index 100%
rename from colossalai/nn/parallel/data_parallel.py
rename to colossalai/legacy/nn/parallel/data_parallel.py
diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/legacy/nn/parallel/layers/__init__.py
similarity index 56%
rename from colossalai/nn/parallel/layers/__init__.py
rename to colossalai/legacy/nn/parallel/layers/__init__.py
index 29b8353e63c5..f38124efedf7 100644
--- a/colossalai/nn/parallel/layers/__init__.py
+++ b/colossalai/legacy/nn/parallel/layers/__init__.py
@@ -1,10 +1,17 @@
+from .cache_embedding import (
+ CachedEmbeddingBag,
+ CachedParamMgr,
+ EvictionStrategy,
+ LimitBuffIndexCopyer,
+ ParallelCachedEmbeddingBag,
+ ParallelCachedEmbeddingBagTablewise,
+ ParallelCachedEmbeddingBagTablewiseSpiltCache,
+ TablewiseEmbeddingBagConfig,
+)
from .colo_module import ColoModule
-from .linear import ColoLinear
from .embedding import ColoEmbedding
-from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-
-from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
- ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache
+from .linear import ColoLinear
+from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module
__all__ = [
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py
similarity index 100%
rename from colossalai/nn/parallel/layers/cache_embedding/__init__.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py
index 5bbc931a79dc..d87930c1c6b3 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py
@@ -1,8 +1,8 @@
from .cache_mgr import CachedParamMgr, EvictionStrategy
-from .copyer import LimitBuffIndexCopyer
from .cached_embedding import CachedEmbeddingBag
-from .parallel_cached_embedding import ParallelCachedEmbeddingBag
+from .copyer import LimitBuffIndexCopyer
from .embedding_config import TablewiseEmbeddingBagConfig
+from .parallel_cached_embedding import ParallelCachedEmbeddingBag
from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
diff --git a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/base_embedding.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py
index 705835a0ed22..9558c541e703 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py
@@ -1,4 +1,5 @@
import abc
+
import torch.nn as nn
diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py
index a6159856dcce..16530c4ce7b8 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -1,12 +1,14 @@
+import sys
+from contextlib import contextmanager
+from enum import Enum
+from typing import List, Optional
+
import numpy as np
import torch
-from torch.profiler import record_function
-from typing import List, Optional
from contexttimer import Timer
+from torch.profiler import record_function
+
from .copyer import LimitBuffIndexCopyer
-from enum import Enum
-import sys
-from contextlib import contextmanager
class EvictionStrategy(Enum):
@@ -35,7 +37,7 @@ def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
class CachedParamMgr(torch.nn.Module):
"""
    Manages embedding weights on CPU and CUDA memory using a software cache.
- CPU maintains the entire original weight.
+ CPU maintains the entire original weight.
CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
    During training, the GPU transmits embedding rows between CPU and GPU as needed.
Args:
@@ -115,7 +117,7 @@ def timer(self, name):
self._elapsed_dict[name] += t.elapsed
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
- """_find_evict_gpu_idxs
+ """_find_evict_gpu_idxs
        Find the GPU idxs to be evicted, according to their frequency.
        Args:
            evict_num (int): how many rows have to be evicted
@@ -202,7 +204,7 @@ def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7
"""reorder
        Reorder the weight rows according to the ids' frequency in the dataset before training.
        Executed only once before training; this is also known as the warmup phase.
-
+
Note:
If you would like to use the DATASET as the eviction strategy, you must call this function.
Note:
@@ -516,7 +518,7 @@ def _evict(self) -> int:
"""
        Deprecated: evict one row from CUDA to CPU.
- Returns:
+ Returns:
(int) : the slot id be evicted.
"""
mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1)
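
The docstrings above describe a two-tier weight store: the full embedding stays on CPU while a cuda_row_num-row cache lives on CUDA, with rows evicted by access frequency. A hedged pseudocode sketch of the DATASET-strategy warmup flow (argument names follow the docstrings; the exact constructor signature is abbreviated here):

    # Pseudocode, not the exact API: stage the hottest rows on CUDA before training.
    mgr = CachedParamMgr(weight, cuda_row_num=100_000)      # full weight remains on CPU
    mgr.reorder(ids_freq_mapping=freqs, warmup_ratio=0.7)   # ~70% of the cache pre-filled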
diff --git a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py
similarity index 98%
rename from colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py
index a74cb8d94bab..bc7d178906da 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py
@@ -1,10 +1,11 @@
+from typing import Iterator, List, Optional, Tuple, Union
+
import torch
import torch.nn.functional as F
-from typing import List, Optional, Iterator, Tuple, Union
+from torch.nn.parameter import Parameter
from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr, EvictionStrategy
-from torch.nn.parameter import Parameter
class CachedEmbeddingBag(BaseEmbeddingBag):
@@ -27,7 +28,7 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
        include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format. Defaults to False.
dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32.
device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu.
- cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row
+            cache_ratio (float, optional): cache ratio of the #cuda_weight_row / #cpu_weight_row
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occurs in dataset. Defaults to None.
            warmup_ratio (float, optional): the fraction of the CUDA cache that is warmed up with hot rows. Defaults to 0.7.
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0.
@@ -85,10 +86,10 @@ def _preprocess(self,
buffer_size=50_000,
pin_weight=False):
"""
- Called after initialized.
+        Called after initialization.
Reorder the weight rows according to the ids_freq_mapping.
Then, let the weights of the Module be managed by a CachedParamMgr.
-
+
Args:
            cuda_row_num (int): number of rows that can be hosted in CUDA memory
            ids_freq_mapping (List[int]): a list whose index is the id and whose value is that id's frequency
diff --git a/colossalai/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py
similarity index 97%
rename from colossalai/nn/parallel/layers/cache_embedding/copyer.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py
index aa1f794482f9..804a07f88207 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/copyer.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py
@@ -3,7 +3,7 @@
class LimitBuffIndexCopyer(object):
- """LimitBuffIndexCopyer
+ """LimitBuffIndexCopyer
Index Copy using limited temp buffer on CUDA.
Args:
@@ -15,7 +15,7 @@ def __init__(self, size: int) -> None:
@torch.no_grad()
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
- """copy
+ """copy
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
    The valid rows in the src tensor are continuous, while rows in the tgt tensor are scattered.
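
The copy path documented above (src[src_index] -> temp buffer -> tgt[tgt_index]) bounds peak memory by streaming rows through a fixed-size buffer. A minimal standalone sketch of the same idea, assuming a CPU source and a CUDA target (illustrative; not the class body):

    import torch

    def buffered_index_copy(dim, src_index, tgt_index, src, tgt, buff_size=1024):
        # Move rows through a bounded temp buffer instead of gathering them all at once.
        for start in range(0, src_index.numel(), buff_size):
            s = src_index[start:start + buff_size]
            t = tgt_index[start:start + buff_size]
            tmp = src.index_select(dim, s).to(tgt.device, non_blocking=True)
            tgt.index_copy_(dim, t.to(tgt.device), tmp)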
diff --git a/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py
similarity index 100%
rename from colossalai/nn/parallel/layers/cache_embedding/embedding_config.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
similarity index 96%
rename from colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
index d7f77e195f4b..79d7672b26bc 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
@@ -1,12 +1,13 @@
+from typing import Iterator, List, Optional, Tuple
+
import torch
import torch.nn.functional as F
-from typing import List, Optional, Iterator, Tuple
-from .cached_embedding import CachedEmbeddingBag
-from colossalai.nn._ops._utils import dual_all_to_all
+from colossalai.legacy.nn._ops._utils import dual_all_to_all
+from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
-from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy
+from .cached_embedding import CachedEmbeddingBag
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
index 949f85ad4baf..116d836b7139 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
@@ -1,15 +1,16 @@
+import time
+from typing import List
+
import torch
import torch.distributed as dist
import torch.nn.functional as F
-from .cached_embedding import CachedEmbeddingBag
-from .cache_mgr import EvictionStrategy
-from .embedding_config import TablewiseEmbeddingBagConfig
+from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
-from colossalai.nn._ops._utils import dual_all_to_all_tablewise
-from typing import List
-import time
+from .cache_mgr import EvictionStrategy
+from .cached_embedding import CachedEmbeddingBag
+from .embedding_config import TablewiseEmbeddingBagConfig
class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
index 80a54b4fadd4..0014c784fba1 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
@@ -1,17 +1,17 @@
+import abc
+from typing import List
+
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.profiler import record_function
-from .cached_embedding import CachedEmbeddingBag
-
+from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
-from colossalai.nn._ops._utils import dual_all_to_all_tablewise
-from .embedding_config import TablewiseEmbeddingBagConfig
-from .cache_mgr import EvictionStrategy
-from typing import List
-import abc
+from .cache_mgr import EvictionStrategy
+from .cached_embedding import CachedEmbeddingBag
+from .embedding_config import TablewiseEmbeddingBagConfig
class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
diff --git a/colossalai/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py
similarity index 98%
rename from colossalai/nn/parallel/layers/colo_module.py
rename to colossalai/legacy/nn/parallel/layers/colo_module.py
index 8f0f5d5f520a..a0a3eb40cf08 100644
--- a/colossalai/nn/parallel/layers/colo_module.py
+++ b/colossalai/legacy/nn/parallel/layers/colo_module.py
@@ -1,6 +1,7 @@
-from colossalai.tensor.distspec import _DistSpec
+from typing import Dict, List
+
from colossalai.tensor import ComputePattern
-from typing import List, Dict
+from colossalai.tensor.distspec import _DistSpec
class ColoModule(object):
diff --git a/colossalai/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py
similarity index 92%
rename from colossalai/nn/parallel/layers/embedding.py
rename to colossalai/legacy/nn/parallel/layers/embedding.py
index ccacc1ead297..3e4e7ffd8de7 100644
--- a/colossalai/nn/parallel/layers/embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/embedding.py
@@ -1,5 +1,6 @@
+from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
+
from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoEmbedding(ColoModule):
diff --git a/colossalai/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py
similarity index 93%
rename from colossalai/nn/parallel/layers/linear.py
rename to colossalai/legacy/nn/parallel/layers/linear.py
index 84a8c042587d..e391cf808933 100644
--- a/colossalai/nn/parallel/layers/linear.py
+++ b/colossalai/legacy/nn/parallel/layers/linear.py
@@ -1,5 +1,6 @@
+from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
+
from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoLinear(ColoModule):
diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py
similarity index 99%
rename from colossalai/nn/parallel/layers/module_utils.py
rename to colossalai/legacy/nn/parallel/layers/module_utils.py
index 38d128cc705e..191266fa70fd 100644
--- a/colossalai/nn/parallel/layers/module_utils.py
+++ b/colossalai/legacy/nn/parallel/layers/module_utils.py
@@ -1,9 +1,11 @@
from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup
-from colossalai.tensor import distspec
-from . import ColoModule
+
import torch
+from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec
+
+from . import ColoModule
+
_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
diff --git a/colossalai/nn/parallel/reducer.py b/colossalai/legacy/nn/parallel/reducer.py
similarity index 100%
rename from colossalai/nn/parallel/reducer.py
rename to colossalai/legacy/nn/parallel/reducer.py
diff --git a/colossalai/registry/__init__.py b/colossalai/legacy/registry/__init__.py
similarity index 100%
rename from colossalai/registry/__init__.py
rename to colossalai/legacy/registry/__init__.py
diff --git a/colossalai/registry/registry.py b/colossalai/legacy/registry/registry.py
similarity index 98%
rename from colossalai/registry/registry.py
rename to colossalai/legacy/registry/registry.py
index 8a4173f7ab99..50d6b74c5617 100644
--- a/colossalai/registry/registry.py
+++ b/colossalai/legacy/registry/registry.py
@@ -6,7 +6,7 @@
class Registry:
- """This is a registry class used to register classes and modules so that a universal
+ """This is a registry class used to register classes and modules so that a universal
object builder can be enabled.
Args:
@@ -42,7 +42,7 @@ def register_module(self, module_class):
return module_class
def get_module(self, module_name: str):
- """Retrieves a module with name `module_name` and returns the module if it has
+ """Retrieves a module with name `module_name` and returns the module if it has
already been registered before.
Args:
diff --git a/colossalai/trainer/__init__.py b/colossalai/legacy/trainer/__init__.py
similarity index 100%
rename from colossalai/trainer/__init__.py
rename to colossalai/legacy/trainer/__init__.py
diff --git a/colossalai/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py
similarity index 98%
rename from colossalai/trainer/_trainer.py
rename to colossalai/legacy/trainer/_trainer.py
index bfe1c403fd48..1847e56222a1 100644
--- a/colossalai/trainer/_trainer.py
+++ b/colossalai/legacy/trainer/_trainer.py
@@ -1,14 +1,13 @@
-from typing import Union, List, Any
+from typing import Any, List, Union
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
-from colossalai.engine import Engine
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.trainer.hooks import BaseHook
from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer
-from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from colossalai.trainer.hooks import BaseHook
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
class Trainer:
diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/legacy/trainer/hooks/__init__.py
similarity index 75%
rename from colossalai/trainer/hooks/__init__.py
rename to colossalai/legacy/trainer/hooks/__init__.py
index 4d36093833d9..bf9cc6421b67 100644
--- a/colossalai/trainer/hooks/__init__.py
+++ b/colossalai/legacy/trainer/hooks/__init__.py
@@ -1,7 +1,12 @@
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook
-from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
- TensorboardHook)
+from ._log_hook import (
+ LogMemoryByEpochHook,
+ LogMetricByEpochHook,
+ LogMetricByStepHook,
+ LogTimingByEpochHook,
+ TensorboardHook,
+)
from ._lr_scheduler_hook import LRSchedulerHook
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
diff --git a/colossalai/trainer/hooks/_base_hook.py b/colossalai/legacy/trainer/hooks/_base_hook.py
similarity index 100%
rename from colossalai/trainer/hooks/_base_hook.py
rename to colossalai/legacy/trainer/hooks/_base_hook.py
diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
similarity index 96%
rename from colossalai/trainer/hooks/_checkpoint_hook.py
rename to colossalai/legacy/trainer/hooks/_checkpoint_hook.py
index 3bcb32cd2dcb..6b150d29139f 100644
--- a/colossalai/trainer/hooks/_checkpoint_hook.py
+++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
@@ -1,11 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
-from colossalai.logging import get_dist_logger
-from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
from colossalai.utils.checkpointing import save_checkpoint
+
from ._lr_scheduler_hook import LRSchedulerHook
diff --git a/colossalai/trainer/hooks/_commons_.py b/colossalai/legacy/trainer/hooks/_commons_.py
similarity index 100%
rename from colossalai/trainer/hooks/_commons_.py
rename to colossalai/legacy/trainer/hooks/_commons_.py
diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py
similarity index 98%
rename from colossalai/trainer/hooks/_log_hook.py
rename to colossalai/legacy/trainer/hooks/_log_hook.py
index 5b1f33983422..7d9ad19aa9e9 100644
--- a/colossalai/trainer/hooks/_log_hook.py
+++ b/colossalai/legacy/trainer/hooks/_log_hook.py
@@ -3,17 +3,17 @@
import os
import os.path as osp
-
from typing import List
+
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
- is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
+
from ._base_hook import BaseHook
from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric
class LogByEpochHook(BaseHook):
diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
similarity index 97%
rename from colossalai/trainer/hooks/_lr_scheduler_hook.py
rename to colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
index c6da33442dc3..6d60966da12a 100644
--- a/colossalai/trainer/hooks/_lr_scheduler_hook.py
+++ b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
@@ -1,6 +1,7 @@
-from colossalai.registry import HOOKS
from torch import Tensor
+from colossalai.legacy.registry import HOOKS
+
from ._metric_hook import LearningRateMetric, MetricHook
diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py
similarity index 97%
rename from colossalai/trainer/hooks/_metric_hook.py
rename to colossalai/legacy/trainer/hooks/_metric_hook.py
index 526d6c746ec6..f1bd19387cb5 100644
--- a/colossalai/trainer/hooks/_metric_hook.py
+++ b/colossalai/legacy/trainer/hooks/_metric_hook.py
@@ -6,10 +6,11 @@
import torch
import torch.distributed as dist
-from colossalai.communication import all_reduce
+
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.communication import all_reduce
+from colossalai.legacy.registry import HOOKS
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
from ._base_hook import BaseHook
@@ -19,8 +20,8 @@
class Metric(ABC):
"""A basic class of metric collectors. It collects a specific
metric during training or evaluation and would always be used with
- :class:`MetricHook` to help it update its states and show the
- metric. So please use corresponding hook class to make the metric
+ :class:`MetricHook` to help it update its states and show the
+        metric. So please use the corresponding hook class to make the metric
        collector work.
Args:
@@ -220,9 +221,9 @@ def is_better(a, b) -> bool:
class MetricHook(BaseHook):
- """Specialized hook classes for :class:`Metric`.
- Some help metric collectors initialize, reset and
- update their states. Others are used to display and
+ """Specialized hook classes for :class:`Metric`.
+ Some help metric collectors initialize, reset and
+ update their states. Others are used to display and
record the metric.
Args:
@@ -355,7 +356,7 @@ def get_last_step_value(self) -> float:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
- gpc.get_world_size(ParallelMode.DATA)
+ gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -366,7 +367,7 @@ def get_last_step_info(self) -> str:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
- gpc.get_world_size(ParallelMode.DATA)
+ gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py
index af7b7de54a8d..f9abe4a2a2b6 100644
--- a/colossalai/logging/logger.py
+++ b/colossalai/logging/logger.py
@@ -6,8 +6,7 @@
from pathlib import Path
from typing import List, Union
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
+import torch.distributed as dist
class DistributedLogger:
@@ -63,6 +62,7 @@ def __init__(self, name):
self._logger.propagate = False
DistributedLogger.__instances[name] = self
+ self.rank = dist.get_rank() if dist.is_initialized() else 0
@staticmethod
def __get_call_info():
@@ -109,16 +109,10 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF
# create log directory
path.mkdir(parents=True, exist_ok=True)
- # set the default file name if path is a directory
- if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
- rank = 0
- else:
- rank = colossalai.core.global_context.get_global_rank()
-
if suffix is not None:
- log_file_name = f'rank_{rank}_{suffix}.log'
+ log_file_name = f'rank_{self.rank}_{suffix}.log'
else:
- log_file_name = f'rank_{rank}.log'
+ log_file_name = f'rank_{self.rank}.log'
path = path.joinpath(log_file_name)
# add file handler
@@ -128,19 +122,14 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
- def _log(self,
- level,
- message: str,
- parallel_mode: ParallelMode = ParallelMode.GLOBAL,
- ranks: List[int] = None) -> None:
+ def _log(self, level, message: str, ranks: List[int] = None) -> None:
if ranks is None:
getattr(self._logger, level)(message)
else:
- local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
- if local_rank in ranks:
+ if self.rank in ranks:
getattr(self._logger, level)(message)
- def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def info(self, message: str, ranks: List[int] = None) -> None:
"""Log an info message.
Args:
@@ -150,10 +139,10 @@ def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('info', message_prefix, parallel_mode, ranks)
- self._log('info', message, parallel_mode, ranks)
+ self._log('info', message_prefix, ranks)
+ self._log('info', message, ranks)
- def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def warning(self, message: str, ranks: List[int] = None) -> None:
"""Log a warning message.
Args:
@@ -163,10 +152,10 @@ def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBA
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('warning', message_prefix, parallel_mode, ranks)
- self._log('warning', message, parallel_mode, ranks)
+ self._log('warning', message_prefix, ranks)
+ self._log('warning', message, ranks)
- def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def debug(self, message: str, ranks: List[int] = None) -> None:
"""Log a debug message.
Args:
@@ -176,10 +165,10 @@ def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('debug', message_prefix, parallel_mode, ranks)
- self._log('debug', message, parallel_mode, ranks)
+ self._log('debug', message_prefix, ranks)
+ self._log('debug', message, ranks)
- def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def error(self, message: str, ranks: List[int] = None) -> None:
"""Log an error message.
Args:
@@ -189,5 +178,5 @@ def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('error', message_prefix, parallel_mode, ranks)
- self._log('error', message, parallel_mode, ranks)
+ self._log('error', message_prefix, ranks)
+ self._log('error', message, ranks)
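
With this refactor the logger derives its rank from torch.distributed at construction time instead of querying the global parallel context, so it also works before (or without) ColossalAI's context being initialized. A short usage sketch:

    from colossalai.logging import get_dist_logger

    logger = get_dist_logger()
    logger.info("visible on every rank")
    logger.info("visible on rank 0 only", ranks=[0])
    logger.log_to_file("./logs", suffix="train")   # writes ./logs/rank_<rank>_train.log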
diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py
index 910ad203180c..c6c4d3042556 100644
--- a/colossalai/nn/__init__.py
+++ b/colossalai/nn/__init__.py
@@ -1,6 +1,5 @@
-from ._ops import *
+from .init import *
from .layer import *
from .loss import *
from .lr_scheduler import *
-from .metric import *
from .optimizer import *
diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index b705632f8040..edd986ef5e82 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -1,10 +1,2 @@
-from .colossalai_layer import *
-from .parallel_1d import *
-from .parallel_2d import *
-from .parallel_2p5d import *
-from .parallel_3d import *
-from .parallel_sequence import *
from .moe import *
from .utils import *
-from .vanilla import *
-from .wrapper import *
diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py
deleted file mode 100644
index 2353851df665..000000000000
--- a/colossalai/nn/layer/parallel_1d/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
- PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)
-
-__all__ = [
- 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
- 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
-]
diff --git a/colossalai/nn/layer/utils.py b/colossalai/nn/layer/utils.py
new file mode 100644
index 000000000000..dc12ff8daa4e
--- /dev/null
+++ b/colossalai/nn/layer/utils.py
@@ -0,0 +1,14 @@
+def divide(numerator, denominator):
+ """Only allow exact division.
+
+ Args:
+ numerator (int): Numerator of the division.
+ denominator (int): Denominator of the division.
+
+ Returns:
+ int: the result of exact division.
+ """
+    assert denominator != 0, 'denominator cannot be zero'
+ assert numerator % denominator == 0, \
+ '{} is not divisible by {}'.format(numerator, denominator)
+ return numerator // denominator
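
This relocated helper is what tensor-parallel layers use to compute shard sizes, failing fast on uneven splits:

    from colossalai.nn.layer.utils import divide

    divide(4096, 8)   # 512
    divide(4096, 7)   # AssertionError: 4096 is not divisible by 7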
diff --git a/colossalai/nn/layer/utils/__init__.py b/colossalai/nn/layer/utils/__init__.py
deleted file mode 100644
index 7e999ee82149..000000000000
--- a/colossalai/nn/layer/utils/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
- set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
-
-__all__ = [
- 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
- 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
-]
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index 373e4ec9468b..ee2add48ab91 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1,41 +1 @@
-from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.nn.layer.utils import get_tensor_parallel_mode
-from torch import nn
-from torch.nn.modules.loss import *
-from torch.nn.modules.loss import _Loss
-
-from .loss_1d import VocabParallelCrossEntropyLoss1D
-from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
-from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
-from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
from .loss_moe import MoeCrossEntropyLoss, MoeLoss
-
-_parallel_cross_entropy = {
- '2d': CrossEntropyLoss2D,
- '2.5d': CrossEntropyLoss2p5D,
- '3d': CrossEntropyLoss3D,
-}
-
-_vocab_parallel_cross_entropy = {
- '1d': VocabParallelCrossEntropyLoss1D,
- '2d': VocabParallelCrossEntropyLoss2D,
- '2.5d': VocabParallelCrossEntropyLoss2p5D,
- '3d': VocabParallelCrossEntropyLoss3D,
-}
-
-
-class CrossEntropyLoss(_Loss):
-
- def __init__(self, reduction: bool = True, *args, **kwargs):
- super().__init__()
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel is not None and env.vocab_parallel:
- self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
- elif tensor_parallel is None or tensor_parallel == '1d':
- reduction = 'mean' if reduction else 'none'
- self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
- else:
- self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
-
- def forward(self, *args):
- return self.loss(*args)
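
Only the MoE losses remain in colossalai.nn.loss; the tensor-parallel cross-entropy variants and the mode-based CrossEntropyLoss dispatch removed here now live under the legacy namespace. A hedged migration sketch (module path inferred from the renames in this patch):

    from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D

    criterion = VocabParallelCrossEntropyLoss1D(reduction=True)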
diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py
index a8b18a3e37ee..40cea788c3c3 100644
--- a/colossalai/nn/loss/loss_moe.py
+++ b/colossalai/nn/loss/loss_moe.py
@@ -1,80 +1,81 @@
-import torch.nn as nn
-from colossalai.registry import LOSSES
-from torch.nn.modules.loss import _Loss
-from colossalai.context.moe_context import MOE_CONTEXT
-
-
-@LOSSES.register_module
-class MoeCrossEntropyLoss(_Loss):
- r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
-
- Args:
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
- aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01.
-
- The ``args`` and ``kwargs`` should include parameters below:
- ::
-
- weight (Tensor, optional)
- size_average (bool, optional)
- ignore_index (int, optional)
- reduce (bool, optional)
- reduction (str, optional)
- label_smoothing (float, optional)
-
- More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
- `Cross_entropy `_.
- """
-
- def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
- super().__init__()
- self.loss = nn.CrossEntropyLoss(*args, **kwargs)
- self.aux_weight = aux_weight
-
- def forward(self, *args):
- """
- The ``args`` should at least include parameters below:
- ::
-
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
- More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
- `Cross_entropy `_.
- """
- main_loss = self.loss(*args)
- aux_loss = MOE_CONTEXT.get_loss()
- return main_loss + self.aux_weight * aux_loss
-
-
-@LOSSES.register_module
-class MoeLoss(_Loss):
- """A wrapper class for any loss module to add with auxiliary loss.
-
- Args:
- aux_weight (float): Weight of auxiliary loss in total loss.
- loss_fn (``Callable``): Loss function.
- args (list): Args in loss function.
- kwargs (dict): Kwargs in loss function
- """
-
- def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
- super().__init__()
- self.loss_fn = loss_fn(*args, **kwargs)
- self.aux_weight = aux_weight
-
- def forward(self, *args, **kwargs):
- """
- The ``args`` and ``kwargs`` should at least include parameters below:
- ::
-
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
- Note:
- The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
- """
- main_loss = self.loss_fn(*args, **kwargs)
- aux_loss = MOE_CONTEXT.get_loss()
- return main_loss + self.aux_weight * aux_loss
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.registry import LOSSES
+
+
+@LOSSES.register_module
+class MoeCrossEntropyLoss(_Loss):
+    r"""torch.nn.CrossEntropyLoss with an auxiliary loss added.
+
+ Args:
+ input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+        aux_weight (float, optional): Weight of the auxiliary loss in the total loss. Defaults to 0.01.
+
+ The ``args`` and ``kwargs`` should include parameters below:
+ ::
+
+ weight (Tensor, optional)
+ size_average (bool, optional)
+ ignore_index (int, optional)
+ reduce (bool, optional)
+ reduction (str, optional)
+ label_smoothing (float, optional)
+
+ More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+ `Cross_entropy `_.
+ """
+
+ def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
+ super().__init__()
+ self.loss = nn.CrossEntropyLoss(*args, **kwargs)
+ self.aux_weight = aux_weight
+
+ def forward(self, *args):
+ """
+ The ``args`` should at least include parameters below:
+ ::
+
+ input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+ More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+ `Cross_entropy `_.
+ """
+ main_loss = self.loss(*args)
+ aux_loss = MOE_CONTEXT.get_loss()
+ return main_loss + self.aux_weight * aux_loss
+
+
+@LOSSES.register_module
+class MoeLoss(_Loss):
+    """A wrapper class that adds an auxiliary loss to any loss module.
+
+ Args:
+ aux_weight (float): Weight of auxiliary loss in total loss.
+ loss_fn (``Callable``): Loss function.
+ args (list): Args in loss function.
+ kwargs (dict): Kwargs in loss function
+ """
+
+ def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
+ super().__init__()
+ self.loss_fn = loss_fn(*args, **kwargs)
+ self.aux_weight = aux_weight
+
+ def forward(self, *args, **kwargs):
+ """
+ The ``args`` and ``kwargs`` should at least include parameters below:
+ ::
+
+ input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+ Note:
+            The ``args`` and ``kwargs`` may include different parameters, varying with the loss function.
+ """
+ main_loss = self.loss_fn(*args, **kwargs)
+ aux_loss = MOE_CONTEXT.get_loss()
+ return main_loss + self.aux_weight * aux_loss
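
The rewrite above only normalizes imports and switches to the legacy registry; behavior is unchanged. For context, a usage sketch of the wrapper it defines (logits and targets are placeholders):

    import torch.nn as nn
    from colossalai.nn.loss import MoeLoss

    criterion = MoeLoss(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss)
    loss = criterion(logits, targets)   # main loss + 0.01 * auxiliary loss from MOE_CONTEXT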
diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py
index aab523bef8b3..fb587e1a1341 100644
--- a/colossalai/nn/lr_scheduler/cosine.py
+++ b/colossalai/nn/lr_scheduler/cosine.py
@@ -1,10 +1,8 @@
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
-from colossalai.registry import LR_SCHEDULERS
from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler
-@LR_SCHEDULERS.register_module
class CosineAnnealingLR(_CosineAnnealingLR):
r"""Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial lr and
@@ -48,7 +46,6 @@ def __init__(self, optimizer, total_steps: int, eta_min: int = 0, last_epoch: in
super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class CosineAnnealingWarmupLR(WarmupScheduler):
"""Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied.
@@ -69,7 +66,6 @@ def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min:
super().__init__(optimizer, warmup_steps, base_scheduler)
-@LR_SCHEDULERS.register_module
class FlatAnnealingLR(DelayerScheduler):
"""Flat and cosine annealing learning rate scheduler. The learning rate will be a fixed value before starting decay.
@@ -90,7 +86,6 @@ def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_ep
super().__init__(optimizer, flat_steps, base_scheduler, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class FlatAnnealingWarmupLR(WarmupDelayerScheduler):
"""Flat and cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be
applied, and then the learning rate will be a fixed value before starting decay.
diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py
index 556938b8a60c..21a865e4c12b 100644
--- a/colossalai/nn/lr_scheduler/linear.py
+++ b/colossalai/nn/lr_scheduler/linear.py
@@ -1,9 +1,6 @@
from torch.optim.lr_scheduler import _LRScheduler
-from colossalai.registry import LR_SCHEDULERS
-
-@LR_SCHEDULERS.register_module
class LinearWarmupLR(_LRScheduler):
"""Linearly warmup learning rate and then linearly decay.
diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py
index 29531a9e3855..c428c911c94d 100644
--- a/colossalai/nn/lr_scheduler/multistep.py
+++ b/colossalai/nn/lr_scheduler/multistep.py
@@ -2,11 +2,9 @@
from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR
-from colossalai.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler
-@LR_SCHEDULERS.register_module
class MultiStepLR(_MultiStepLR):
"""Decays the learning rate of each parameter group by gamma once the
    number of epochs reaches one of the milestones. Notice that such decay can
@@ -32,7 +30,6 @@ def __init__(self,
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class MultiStepWarmupLR(WarmupScheduler):
"""Multistep learning rate scheduler with warmup.
diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py
index 8007fd36008e..6835b3ee1cf2 100644
--- a/colossalai/nn/lr_scheduler/onecycle.py
+++ b/colossalai/nn/lr_scheduler/onecycle.py
@@ -1,9 +1,6 @@
from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR
-from colossalai.registry import LR_SCHEDULERS
-
-@LR_SCHEDULERS.register_module
class OneCycleLR(_OneCycleLR):
r"""Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py
index 16352bc5175f..4f2249720ef6 100644
--- a/colossalai/nn/lr_scheduler/poly.py
+++ b/colossalai/nn/lr_scheduler/poly.py
@@ -1,10 +1,8 @@
from torch.optim.lr_scheduler import _LRScheduler
-from colossalai.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler
-@LR_SCHEDULERS.register_module
class PolynomialLR(_LRScheduler):
"""Polynomial learning rate scheduler.
@@ -40,7 +38,6 @@ def _get_closed_form_lr(self):
for base_lr in self.base_lrs]
-@LR_SCHEDULERS.register_module
class PolynomialWarmupLR(WarmupScheduler):
"""Polynomial learning rate scheduler with warmup.
diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py
index 05d2a49c1ea5..8846e13c7511 100644
--- a/colossalai/nn/lr_scheduler/torch.py
+++ b/colossalai/nn/lr_scheduler/torch.py
@@ -1,12 +1,9 @@
+from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
from torch.optim.lr_scheduler import StepLR as _StepLR
-from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
-
-from colossalai.registry import LR_SCHEDULERS
-@LR_SCHEDULERS.register_module
class LambdaLR(_LambdaLR):
"""Sets the learning rate of each parameter group to the initial lr
times a given function. When last_epoch=-1, sets initial lr as lr.
@@ -24,7 +21,6 @@ def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1)
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class MultiplicativeLR(_MultiplicativeLR):
"""Multiply the learning rate of each parameter group by the factor given
in the specified function. When last_epoch=-1, sets initial lr as lr.
@@ -42,7 +38,6 @@ def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1)
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class StepLR(_StepLR):
"""Decays the learning rate of each parameter group by gamma every
step_size epochs. Notice that such decay can happen simultaneously with
@@ -61,7 +56,6 @@ def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.
super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class ExponentialLR(_ExponentialLR):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr
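
Across the scheduler files above, the @LR_SCHEDULERS.register_module decorators are dropped together with the legacy registry import; the schedulers are now plain classes instantiated directly rather than built by name from a registry. For example:

    from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

    scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10_000, warmup_steps=500)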
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 3a6d37103398..9767fcb8b1e2 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -4,12 +4,10 @@
import torch
from colossalai.kernel.op_builder import CPUAdamBuilder
-from colossalai.registry import OPTIMIZERS
from .nvme_optimizer import NVMeOptimizer
-@OPTIMIZERS.register_module
class CPUAdam(NVMeOptimizer):
"""Implements Adam algorithm.
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 82a6250f1fd1..3a05a34f52d2 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -8,11 +8,9 @@
'''
import torch
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
-@OPTIMIZERS.register_module
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py
index 72520064e98b..a2807d70f454 100644
--- a/colossalai/nn/optimizer/fused_lamb.py
+++ b/colossalai/nn/optimizer/fused_lamb.py
@@ -1,11 +1,9 @@
# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py
import torch
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
-@OPTIMIZERS.register_module
class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py
index 468713b223c1..59a93a8be9c7 100644
--- a/colossalai/nn/optimizer/fused_sgd.py
+++ b/colossalai/nn/optimizer/fused_sgd.py
@@ -2,11 +2,9 @@
import torch
from torch.optim.optimizer import Optimizer, required
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
-@OPTIMIZERS.register_module
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index 84903ac36832..e08df410effe 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -4,13 +4,11 @@
from torch.optim import Adam
from colossalai.kernel.op_builder import FusedOptimBuilder
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
from .cpu_adam import CPUAdam
-@OPTIMIZERS.register_module
class HybridAdam(CPUAdam):
"""Implements Adam algorithm.
diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py
index 399ad39b6658..d5de267f73ee 100644
--- a/colossalai/nn/optimizer/lamb.py
+++ b/colossalai/nn/optimizer/lamb.py
@@ -5,10 +5,7 @@
import torch
from torch.optim import Optimizer
-from colossalai.registry import OPTIMIZERS
-
-@OPTIMIZERS.register_module
class Lamb(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py
index 212f66671a0d..58393fdae4bf 100644
--- a/colossalai/nn/optimizer/lars.py
+++ b/colossalai/nn/optimizer/lars.py
@@ -5,10 +5,7 @@
import torch
from torch.optim import Optimizer
-from colossalai.registry import OPTIMIZERS
-
-@OPTIMIZERS.register_module
class Lars(Optimizer):
r"""Implements the LARS optimizer from `"Large batch training of convolutional networks"
`_.
@@ -22,28 +19,24 @@ class Lars(Optimizer):
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""
- def __init__(
- self,
- params: Iterable[torch.nn.Parameter],
- lr=1e-3,
- momentum=0,
- eeta=1e-3,
- weight_decay=0,
- epsilon=0.0
- ) -> None:
+ def __init__(self,
+ params: Iterable[torch.nn.Parameter],
+ lr=1e-3,
+ momentum=0,
+ eeta=1e-3,
+ weight_decay=0,
+ epsilon=0.0) -> None:
if not isinstance(lr, float) or lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
- raise ValueError(
- "Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if eeta <= 0 or eeta > 1:
raise ValueError("Invalid eeta value: {}".format(eeta))
if epsilon < 0:
raise ValueError("Invalid epsilon value: {}".format(epsilon))
- defaults = dict(lr=lr, momentum=momentum,
- weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
+ defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
super().__init__(params, defaults)
@@ -76,11 +69,9 @@ def step(self, closure=None):
if lars:
w_norm = torch.norm(p)
g_norm = torch.norm(p.grad)
- trust_ratio = torch.where(
- w_norm > 0 and g_norm > 0,
- eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
- torch.ones_like(w_norm)
- )
+ trust_ratio = torch.where(w_norm > 0 and g_norm > 0,
+ eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
+ torch.ones_like(w_norm))
trust_ratio.clamp_(0.0, 50)
scaled_lr *= trust_ratio.item()
if weight_decay != 0:
@@ -90,8 +81,7 @@ def step(self, closure=None):
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
- buf = param_state['momentum_buffer'] = torch.clone(
- decayed_grad).detach()
+ buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(decayed_grad)
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index af7a00b5c720..aed85cf91512 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -173,14 +173,10 @@ def recv_forward(self, prev_rank: int = None) -> Any:
Returns:
Any: The input tensor or input tensor list.
"""
- if self.stage_manager.is_first_stage():
- input_tensor = None
- else:
- if prev_rank is None:
- prev_rank = self.stage_manager.get_prev_rank()
- cur_rank = self.stage_manager.get_rank()
- input_tensor = _recv_object(prev_rank, cur_rank,
- self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
+ if prev_rank is None:
+ prev_rank = self.stage_manager.get_prev_rank()
+ cur_rank = self.stage_manager.get_rank()
+ input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
return input_tensor
@@ -193,14 +189,11 @@ def recv_backward(self, next_rank: int = None) -> Any:
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
- if self.stage_manager.is_last_stage():
- output_tensor_grad = None
- else:
- if next_rank is None:
- next_rank = self.stage_manager.get_next_rank()
- cur_rank = self.stage_manager.get_rank()
- output_tensor_grad = _recv_object(next_rank, cur_rank,
- self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
+ if next_rank is None:
+ next_rank = self.stage_manager.get_next_rank()
+ cur_rank = self.stage_manager.get_rank()
+ output_tensor_grad = _recv_object(next_rank, cur_rank,
+ self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
return output_tensor_grad
@@ -211,12 +204,10 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
- if not self.stage_manager.is_last_stage():
- if next_rank is None:
- next_rank = self.stage_manager.get_next_rank()
- cur_rank = self.stage_manager.get_rank()
- _send_object(output_object, cur_rank, next_rank,
- self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
+ if next_rank is None:
+ next_rank = self.stage_manager.get_next_rank()
+ cur_rank = self.stage_manager.get_rank()
+ _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
@@ -225,9 +216,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
- if not self.stage_manager.is_first_stage():
- if prev_rank is None:
- prev_rank = self.stage_manager.get_prev_rank()
- cur_rank = self.stage_manager.get_rank()
- _send_object(input_object, cur_rank, prev_rank,
- self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+ if prev_rank is None:
+ prev_rank = self.stage_manager.get_prev_rank()
+ cur_rank = self.stage_manager.get_rank()
+ _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py
index 79913987b7cc..ba8b1591da9d 100644
--- a/colossalai/pipeline/pipelinable.py
+++ b/colossalai/pipeline/pipelinable.py
@@ -1,15 +1,24 @@
-import torch
import inspect
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \
- build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \
- call_module, customized_partition
-from colossalai.nn.layer.utils import CheckpointModule
-from colossalai.tensor import ColoParameter
-from colossalai.core import global_context as gpc
+import torch
+
from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.utils import CheckpointModule
+from colossalai.tensor import ColoParameter
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
+
from .layer_spec import LayerSpec
+from .utils import (
+ build_kwargs_for_function,
+ build_kwargs_for_module,
+ call_module,
+ customized_partition,
+ exec_func_with_kwargs,
+ exec_funcs_with_kwargs,
+ partition_balanced,
+ partition_uniform,
+)
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 3ed9239272f1..583558551b3c 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -1,9 +1,59 @@
-from typing import Any, List, Optional
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
from torch.nn import Module
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import (
+ SUPPORTED_NODES,
+ LeafSpec,
+ TreeSpec,
+ _is_leaf,
+ _register_pytree_node,
+ tree_flatten,
+ tree_map,
+ tree_unflatten,
+)
+
+
+# this registration is for torch versions below 1.13.1; it may be removed in the future
+def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
+ return list(d.values()), list(d.keys())
+
+
+def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
+ return OrderedDict((key, value) for key, value in zip(context, values))
+
+
+_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+
+
+def tree_map_hf(fn: Any, pytree: Any):
+ flat_args, spec = tree_flatten_hf(pytree)
+ return tree_unflatten([fn(i) for i in flat_args], spec)
+
+
+# use this flatten function to handle HuggingFace ModelOutput-style (OrderedDict-based) instances.
+def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
+ """Flattens a pytree into a list of values an a TreeSpec that can be used
+ to reconstruct the pytree.
+ """
+ if isinstance(pytree, OrderedDict):
+ node_type = OrderedDict
+ flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
+ child_pytrees, context = flatten_fn(pytree)
+
+ # Recursively flatten the children
+ result: List[Any] = []
+ children_specs: List['TreeSpec'] = []
+ for child in child_pytrees:
+ flat, child_spec = tree_flatten_hf(child)
+ result += flat
+ children_specs.append(child_spec)
+ return result, TreeSpec(node_type, context, children_specs)
+ else:
+ result, tree_spec = tree_flatten(pytree)
+ return result, tree_spec
def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
@@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
return x
-def merge_batch(data: List[Any]) -> Any:
+def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
"""Merge micro batches into a batch.
Args:
@@ -118,12 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
flattened_data = []
tree_spec = None
for d in data:
- elems, tree_spec = tree_flatten(d)
+ # elems should be an instance of OrderedDict
+ elems, tree_spec = tree_flatten_hf(d)
flattened_data.append(elems)
merged_data = []
+
for elem_batch in zip(*flattened_data):
if isinstance(elem_batch[0], torch.Tensor):
- merged_data.append(torch.cat(elem_batch, dim=0))
+ if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
+ merged_data.append(None)
+ else:
+ merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
else:
merged_data.append(list(elem_batch))
return tree_unflatten(merged_data, tree_spec)
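For reference, here is a minimal sketch of how the helpers above interact (assuming this module is importable as `colossalai.pipeline.schedule._utils`, per the path in this diff; shapes and values are illustrative):

```python
from collections import OrderedDict

import torch

# path taken from the diff above; the OrderedDict pytree registration
# happens at import time (it targets torch < 1.13.1)
from colossalai.pipeline.schedule._utils import merge_batch, tree_flatten_hf

# two micro-batch outputs shaped like a HuggingFace ModelOutput
# (an OrderedDict subclass)
mb0 = OrderedDict(logits=torch.ones(2, 4), loss=torch.tensor(0.5))
mb1 = OrderedDict(logits=torch.zeros(2, 4), loss=torch.tensor(0.7))

values, spec = tree_flatten_hf(mb0)
print([tuple(v.shape) for v in values])    # [(2, 4), ()] - keys live in the spec

merged = merge_batch([mb0, mb1])
print(merged['logits'].shape)              # torch.Size([4, 4]): cat along dim 0
print(merged['loss'])                      # None: 0-dim losses are dropped
```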
diff --git a/colossalai/pipeline/schedule/base.py b/colossalai/pipeline/schedule/base.py
index 9cd9beded65a..b0fa6e6ad2b8 100644
--- a/colossalai/pipeline/schedule/base.py
+++ b/colossalai/pipeline/schedule/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, Optional
from torch import Tensor
from torch.nn import Module
@@ -14,18 +14,18 @@ def __init__(self, stage_manager: PipelineStageManager) -> None:
def forward_backward_step(self,
model: Module,
- optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[[Any, Any], Tensor],
+ optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Forward and backward step for pipeline training.
Args:
model (Module): Model to be trained.
- optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+ optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
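Since `optimizer` is now optional, the same entry point can drive forward-only evaluation. A minimal sketch, assuming `schedule`, `model`, `criterion`, and `val_loader` already exist on every pipeline rank:

```python
import torch

model.eval()
with torch.no_grad():    # grad disabled -> the forward-only path, no optimizer needed
    result = schedule.forward_backward_step(model,
                                            iter(val_loader),
                                            criterion,
                                            return_loss=True,
                                            return_outputs=True)
# 'loss' and 'outputs' are populated only on the last pipeline stage
print(result['loss'], result['outputs'])
```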
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
new file mode 100644
index 000000000000..6fdb09be5f32
--- /dev/null
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -0,0 +1,372 @@
+from functools import partial
+from typing import Any, Callable, Iterable, List, Optional, Union
+
+import torch
+import torch.cuda
+from torch.nn import Module
+from torch.utils._pytree import tree_map
+
+from colossalai.interface import OptimizerWrapper
+from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.utils.cuda import get_current_device
+
+from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from .base import PipelineSchedule
+
+
+class InterleavedSchedule(PipelineSchedule):
+
+ def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
+ self.num_model_chunks = num_model_chunks
+ assert num_microbatches % self.num_model_chunks == 0, \
+ "Number of microbatches should be an integer multiple of number of model chunks"
+ super().__init__(stage_manager)
+ self.comm = PipelineP2PCommunication(stage_manager)
+ self.num_microbatches = num_microbatches
+ self.batch: Optional[Any] = None
+ self.batch_size: Optional[int] = None
+ self.microbatch_offset: Optional[int] = None
+ self.microbatch_size: Optional[int] = None
+
+ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
+ """Load a batch from data iterator.
+
+ Args:
+ data_iter (Iterable): Data iterator.
+ device (Optional[torch.device], optional): Target device. Defaults to None.
+ """
+ batch = next(data_iter)
+ if device is not None:
+ batch = tree_map(partial(to_device, device=device), batch)
+ self.batch = batch
+ self.batch_size = get_batch_size(batch)
+ self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
+ assert self.batch_size % self.num_microbatches == 0, \
+ "Batch size should divided by the number of microbatches"
+ self.microbatch_size = self.batch_size // self.num_microbatches
+
+ def load_micro_batch(self, model_chunk_id: int) -> Any:
+ """Load a micro batch from the current batch.
+
+ Args:
+            model_chunk_id (int): The current model chunk idx.
+
+ Returns:
+ Any: Micro batch.
+ """
+ micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
+ self.microbatch_offset[model_chunk_id] += self.microbatch_size
+ return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+
+ def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
+ """Helper method to get the model chunk ID given the iteration number.
+
+ Args:
+ microbatch_id (int): the current microbatch idx
+ forward (bool): if is the forward process
+
+ Returns:
+ int: The model chunk idx of the input microbatch_id
+ """
+ microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
+ if not forward:
+ model_chunk_id = (self.num_model_chunks - model_chunk_id - 1)
+ return model_chunk_id
+
+ def is_first_stage(self, model_chunk_id: int) -> bool:
+ """Is the current virtual stage the first stage
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+
+ Returns:
+ bool: Whether the current virtual stage is the first stage.
+ """
+ if self.stage_manager.is_first_stage() and model_chunk_id == 0:
+ return True
+ return False
+
+ def is_last_stage(self, model_chunk_id: int) -> bool:
+ """Is the current virtual stage the last stage
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+
+ Returns:
+ bool: Whether the current virtual stage is the last stage.
+ """
+ if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
+ return True
+ return False
+
+ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
+ """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ prev_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input tensor or input tensor list.
+ """
+ if self.is_first_stage(model_chunk_id):
+ input_tensor = None
+ else:
+ input_tensor = self.comm.recv_forward(prev_rank)
+
+ return input_tensor
+
+ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
+ """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ next_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input gradient tensor or gradient tensor list.
+ """
+ if self.is_last_stage(model_chunk_id):
+ output_tensor_grad = None
+ else:
+ output_tensor_grad = self.comm.recv_backward(next_rank)
+
+ return output_tensor_grad
+
+ def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
+ """Sends the input tensor to the next stage in pipeline.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ output_object (Any): Object to be sent.
+ next_rank (int, optional): The rank of the recipient of the tensor.
+ """
+ if not self.is_last_stage(model_chunk_id):
+ self.comm.send_forward(output_object, next_rank)
+
+ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
+ """Sends the gradient tensor to the previous stage in pipeline.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the recipient of the tensor
+ """
+ if not self.is_first_stage(model_chunk_id):
+ self.comm.send_backward(input_object, prev_rank)
+
+ def forward_step(self,
+ model_chunk: Module,
+ model_chunk_id: int,
+ input_obj: Optional[dict],
+ criterion: Callable,
+ accum_loss: Optional[torch.Tensor] = None,
+ outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
+ """Forward one step of the pipeline
+ Args:
+            model_chunk (Module): The model chunk to be run.
+            model_chunk_id (int): The current model chunk idx.
+ input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
+ criterion (Callable): Criterion to calculate loss.
+ accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+ outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+ """
+ micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
+
+ # for the first stage, input_obj is None
+        # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
+ output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+
+ if self.is_last_stage(model_chunk_id):
+ loss = criterion(output_obj, micro_batch) / self.num_microbatches
+ if accum_loss is not None:
+ accum_loss.add_(loss.detach())
+ if outputs is not None:
+ outputs.append(tree_map(detach, output_obj))
+ return loss
+ else:
+ return output_obj
+
+ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
+ output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
+ """Backward one step of the pipeline
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to update the model
+ input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
+ output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
+ output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
+
+ Returns:
+ Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
+ """
+
+ # Retain the grad on the input_obj.
+ tree_map(retain_grad, input_obj)
+
+ # Backward pass.
+ if output_obj_grad is None:
+ optimizer.backward(output_obj)
+ else:
+ if "backward_tensor_keys" not in output_obj:
+ for k, grad in output_obj_grad.items():
+ optimizer.backward_by_grad(output_obj[k], grad)
+ else:
+ for k, grad in output_obj_grad.items():
+ output_obj[k].grad = grad
+ for k in output_obj["backward_tensor_keys"]:
+ tensor_to_backward = output_obj[k]
+ optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+
+ # Collect the grad of the input_obj.
+ input_obj_grad = None
+ if input_obj is not None:
+ input_obj_grad = {}
+ for k, v in input_obj.items():
+ if isinstance(v, torch.Tensor) and v.grad is not None:
+ input_obj_grad[k] = v.grad
+ return input_obj_grad
+
+ def forward_backward_step(self,
+ model_chunk: Module,
+ data_iter: Iterable,
+ criterion: Callable[..., Any],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = False,
+ return_outputs: bool = False) -> dict:
+ """Runs interleaved 1F1B schedule, with communication between pipeline stages.
+
+ Args:
+ model_chunk (List[Module]): Model Chunk to be trained.
+ data_iter (Iterable): Data iterator.
+ criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+ optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+ Returns:
+ dict: A dict with keys: 'loss' and 'outputs'.
+ """
+ forward_only = not torch.is_grad_enabled()
+ if optimizer is None:
+ assert forward_only, "Optimizer should be passed when doing backward."
+
+ self.load_batch(data_iter)
+ num_model_chunks = len(model_chunk)
+
+        # num_warmup_microbatches counts the warmup steps during which not all stages are active yet
+ num_microbatches = self.num_microbatches * num_model_chunks
+ if forward_only:
+ num_warmup_microbatches = num_microbatches
+ else:
+ num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+ num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
+ num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+
+ num_microbatches_remaining = num_microbatches - num_warmup_microbatches
+
+ # Input, output tensors only need to be saved when doing backward passes
+ input_objs = None
+ output_objs = None
+
+ if not forward_only:
+ input_objs = [[] for _ in range(num_model_chunks)]
+ output_objs = [[] for _ in range(num_model_chunks)]
+
+ outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+
+ if return_loss and self.stage_manager.is_last_stage():
+ accum_loss = torch.zeros(1, device=get_current_device())
+ else:
+ accum_loss = None
+
+        # for ranks except the first one, get into recv state
+        input_obj = self.recv_forward(0)
+        if not forward_only:
+            input_objs[0].append(input_obj)
+ # Run warmup forward passes.
+ for i in range(num_warmup_microbatches):
+ model_chunk_id = self.get_model_chunk_id(i, forward=True)
+
+            # recv first on first rank to avoid sending and receiving at the same time
+ if self.stage_manager.is_first_stage():
+ input_obj = self.recv_forward(model_chunk_id)
+ output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+ self.send_forward(model_chunk_id, output_obj)
+ if not forward_only:
+ input_objs[model_chunk_id].append(input_obj)
+ output_objs[model_chunk_id].append(output_obj)
+ else:
+ output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+ if not forward_only:
+ output_objs[model_chunk_id].append(output_obj)
+ self.send_forward(model_chunk_id, output_obj)
+ if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
+ break
+ else:
+ model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
+
+ input_obj = self.recv_forward(model_chunk_id)
+ if not forward_only:
+ input_objs[model_chunk_id].append(input_obj)
+
+ # Run 1F1B in steady state.
+ for i in range(num_microbatches_remaining):
+ model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
+ last_iteration = (i == (num_microbatches_remaining - 1))
+
+ output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+ if forward_only:
+ self.send_forward(model_chunk_id, output_obj)
+
+ if not last_iteration:
+ input_obj = self.recv_forward(model_chunk_id)
+
+ else:
+ self.send_forward(model_chunk_id, output_obj)
+ # Add input_obj and output_obj to end of list.
+ input_objs[model_chunk_id].append(input_obj)
+ output_objs[model_chunk_id].append(output_obj)
+
+ model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ output_obj_grad = self.recv_backward(model_chunk_id)
+
+            # Pop input_obj and output_obj from the start of the list for
+ # the backward pass.
+ input_obj = input_objs[model_chunk_id].pop(0)
+ output_obj = output_objs[model_chunk_id].pop(0)
+
+ # backward
+ input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+
+ if last_iteration:
+ input_obj = None
+ else:
+ model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
+ input_obj = self.recv_forward(model_chunk_id)
+ model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ self.send_backward(model_chunk_id, input_obj_grad)
+
+ # Run cooldown backward passes.
+ if not forward_only:
+ for i in range(num_microbatches_remaining, num_microbatches):
+ model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
+ input_obj = input_objs[model_chunk_id].pop(0)
+ output_obj = output_objs[model_chunk_id].pop(0)
+
+ output_obj_grad = self.recv_backward(model_chunk_id)
+ input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+ self.send_backward(model_chunk_id, input_obj_grad)
+
+ if outputs is not None:
+ outputs = merge_batch(outputs)
+ return {'loss': accum_loss, 'outputs': outputs}
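To see how the schedule interleaves the chunks, here is a standalone sketch of the `get_model_chunk_id` formula with `num_stages=2` and `num_model_chunks=2` (illustrative values only):

```python
num_stages, num_model_chunks = 2, 2

def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
    # same arithmetic as the method above, without the stage manager
    microbatch_id_in_group = microbatch_id % (num_stages * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // num_stages
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(i, forward=True) for i in range(8)])
# [0, 0, 1, 1, 0, 0, 1, 1] - each chunk runs num_stages microbatches in a row
print([get_model_chunk_id(i, forward=False) for i in range(8)])
# [1, 1, 0, 0, 1, 1, 0, 0] - backward visits the chunks in reverse order
```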
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index ade3cf456fe3..fbd0f9f0d4c0 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -6,25 +6,47 @@
from torch.nn import Module
from torch.utils._pytree import tree_map
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+ detach,
+ get_batch_size,
+ get_micro_batch,
+ merge_batch,
+ model_forward,
+ retain_grad,
+ to_device,
+ tree_map_hf,
+)
from .base import PipelineSchedule
class OneForwardOneBackwardSchedule(PipelineSchedule):
- def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
+ def __init__(self,
+ stage_manager: PipelineStageManager,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None) -> None:
+ """1F1B pipeline schedule.
+
+ Args:
+ stage_manager (PipelineStageManager): Pipeline stage manager
+ num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
+ microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
+ """
super().__init__(stage_manager)
+ assert num_microbatches is not None or microbatch_size is not None, \
+ "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
+ self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
- self.microbatch_size: Optional[int] = None
+ self._use_microbatch_size = num_microbatches is None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -39,9 +61,14 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
- assert self.batch_size % self.num_microbatches == 0, \
- "Batch size should divided by the number of microbatches"
- self.microbatch_size = self.batch_size // self.num_microbatches
+ if not self._use_microbatch_size:
+ assert self.batch_size % self.num_microbatches == 0, \
+ "Batch size should divided by the number of microbatches"
+ self.microbatch_size = self.batch_size // self.num_microbatches
+ else:
+ assert self.batch_size % self.microbatch_size == 0, \
+ "Batch size should divided by the microbatch size"
+ self.num_microbatches = self.batch_size // self.microbatch_size
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.
@@ -53,6 +80,62 @@ def load_micro_batch(self) -> Any:
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+ def recv_forward(self, prev_rank: int = None) -> Any:
+ """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
+ For 1F1B.
+
+ Args:
+ prev_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input tensor or input tensor list.
+ """
+ if self.stage_manager.is_first_stage():
+ input_tensor = None
+ else:
+ input_tensor = self.comm.recv_forward(prev_rank)
+
+ return input_tensor
+
+ def recv_backward(self, next_rank: int = None) -> Any:
+ """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
+ For 1F1B.
+
+ Args:
+ next_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input gradient tensor or gradient tensor list.
+ """
+ if self.stage_manager.is_last_stage():
+ output_tensor_grad = None
+ else:
+ output_tensor_grad = self.comm.recv_backward(next_rank)
+
+ return output_tensor_grad
+
+ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+ """Sends the input tensor to the next stage in pipeline.
+ For 1F1B.
+
+ Args:
+ output_object (Any): Object to be sent.
+ next_rank (int, optional): The rank of the recipient of the tensor.
+ """
+ if not self.stage_manager.is_last_stage():
+ self.comm.send_forward(output_object, next_rank)
+
+ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+ """Sends the gradient tensor to the previous stage in pipeline.
+ For 1F1B.
+
+ Args:
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the recipient of the tensor
+ """
+ if not self.stage_manager.is_first_stage():
+ self.comm.send_backward(input_object, prev_rank)
+
def forward_step(self,
model: Module,
input_obj: Optional[dict],
@@ -72,16 +155,16 @@ def forward_step(self,
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch()
-
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
+
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
- outputs.append(tree_map(detach, output_obj))
+ outputs.append(tree_map_hf(detach, output_obj))
return loss
else:
return output_obj
@@ -102,7 +185,6 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
-
# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
@@ -128,18 +210,18 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
def forward_backward_step(self,
model: Module,
- optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
+ optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model (Module): Model to be trained.
- optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+ optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
@@ -147,6 +229,8 @@ def forward_backward_step(self,
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
+ if optimizer is None:
+ assert forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)
@@ -171,11 +255,11 @@ def forward_backward_step(self,
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
- input_obj = self.comm.recv_forward()
+ input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
- self.comm.send_forward(output_obj)
+ self.send_forward(output_obj)
if not forward_only:
input_objs.append(input_obj)
@@ -185,7 +269,7 @@ def forward_backward_step(self,
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
- input_obj = self.comm.recv_forward()
+ input_obj = self.recv_forward()
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
@@ -193,15 +277,15 @@ def forward_backward_step(self,
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only:
- self.comm.send_forward(output_obj)
+ self.send_forward(output_obj)
if not last_iteration:
- input_obj = self.comm.recv_forward()
+ input_obj = self.recv_forward()
else:
# TODO adjust here
- self.comm.send_forward(output_obj)
- output_obj_grad = self.comm.recv_backward()
+ self.send_forward(output_obj)
+ output_obj_grad = self.recv_backward()
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
@@ -216,8 +300,8 @@ def forward_backward_step(self,
if last_iteration:
input_obj = None
else:
- input_obj = self.comm.recv_forward()
- self.comm.send_backward(input_obj_grad)
+ input_obj = self.recv_forward()
+ self.send_backward(input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
@@ -225,10 +309,12 @@ def forward_backward_step(self,
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
- output_obj_grad = self.comm.recv_backward()
+ output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
- self.comm.send_backward(input_obj_grad)
+ self.send_backward(input_obj_grad)
if outputs is not None:
- outputs = merge_batch(outputs)
+ if isinstance(model, ModelWrapper):
+ model = model.unwrap()
+ outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
return {'loss': accum_loss, 'outputs': outputs}
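With the new constructor, whichever of `num_microbatches`/`microbatch_size` is omitted gets derived in `load_batch`. A short sketch of both construction styles (assuming a `stage_manager` exists; import path taken from this diff):

```python
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule

# with a global batch size of 32, these two are equivalent:
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=4)   # microbatch_size derived as 8
schedule = OneForwardOneBackwardSchedule(stage_manager, microbatch_size=8)    # num_microbatches derived as 4
```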
diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py
index fe228e2270dd..6ba7dc629958 100644
--- a/colossalai/pipeline/stage_manager.py
+++ b/colossalai/pipeline/stage_manager.py
@@ -17,28 +17,24 @@ class PipelineStageManager:
Attributes:
num_stages (int): Number of stages in the pipeline.
stage (int): The current stage.
- num_virtual_stages (int): Number of virtual stages in the pipeline.
- virtual_stage (int): The current virtual stage.
"""
- def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None:
+ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
- self.num_virtual_stages: Optional[int] = None
- self.virtual_stage: Optional[int] = None
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
# init prev and next coord
coord = self.pg_mesh.coordinate()
- if self.stage > 0:
- prev_coord = coord[: self.pipeline_axis] + \
- (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
- self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape)
- if self.stage < self.num_stages - 1:
- next_coord = coord[: self.pipeline_axis] + \
- (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
- self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape)
+ # the prev rank of rank0 is the last rank
+ prev_coord = coord[: self.pipeline_axis] + \
+ (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
+ self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap')
+ # the next rank of the last rank is rank0
+ next_coord = coord[: self.pipeline_axis] + \
+ (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
+ self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap')
# init p2p process groups
stages = list(range(self.num_stages))
@@ -48,32 +44,28 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
- def is_first_stage(self, virtual: bool = False) -> bool:
- """Is the current stage the first stage.
+ if is_virtual:
+ # add the process group of the first rank and the last rank
+ # only used in interleaved pipeline for now
+ group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
+ if self.stage in [stages[0], stages[-1]]:
+ ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
+ self.p2p_groups[tuple(ranks_in_group)] = group
- Args:
- virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
+ def is_first_stage(self) -> bool:
+ """Is the current stage the first stage.
Returns:
bool: Whether the current stage is the first stage.
"""
- if virtual:
- assert self.num_virtual_stages is not None
- return self.virtual_stage == 0
return self.stage == 0
- def is_last_stage(self, virtual: bool = False) -> bool:
+ def is_last_stage(self) -> bool:
"""Is the current stage the last stage.
- Args:
- virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
-
Returns:
bool: Whether the current stage is the last stage.
"""
- if virtual:
- assert self.num_virtual_stages is not None
- return self.virtual_stage == self.num_virtual_stages - 1
return self.stage == self.num_stages - 1
@property
@@ -108,7 +100,6 @@ def get_prev_rank(self) -> int:
Returns:
int: Rank of the previous stage.
"""
- assert not self.is_first_stage(), "Cannot get previous rank in the first stage."
return self.prev_rank
def get_next_rank(self) -> int:
@@ -117,39 +108,8 @@ def get_next_rank(self) -> int:
Returns:
int: Rank of the next stage.
"""
- assert not self.is_last_stage(), "Cannot get next rank in the last stage."
return self.next_rank
- def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
- """Set the number of virtual stages.
-
- Args:
- num_virtual_stages (int): Number of virtual stages.
- """
- self.num_virtual_stages = num_virtual_stages
-
- def set_virtual_stage(self, virtual_stage: int) -> None:
- """Set the virtual stage.
-
- Args:
- virtual_stage (int): Virtual stage.
- """
- self.virtual_stage = virtual_stage
-
- @contextmanager
- def switch_virtual_stage(self, virtual_stage: int) -> None:
- """A context manager to switch virtual stage.
-
- Args:
- virtual_stage (int): Target virtual stage.
- """
- old_stage = self.virtual_stage
- try:
- self.set_virtual_stage(virtual_stage)
- yield
- finally:
- self.set_virtual_stage(old_stage)
-
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.
diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py
index ac8a3ad7d1db..be8428692756 100644
--- a/colossalai/pipeline/utils.py
+++ b/colossalai/pipeline/utils.py
@@ -1,12 +1,13 @@
import heapq
import inspect
+from collections import OrderedDict
+from typing import List
+
import torch
+from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.logging import get_dist_logger
-from colossalai.nn.layer.utils import CheckpointModule
-from typing import List
-from collections import OrderedDict
def _binary_partition(weights: List, start: int, end: int):
"""Returns the binary partition position of `weights`, given the start
@@ -162,7 +163,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
kwargs_offset = 1
elif isinstance(input_tensor, (tuple, OrderedDict)):
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
- # Huggingface will take their own structures based on OrderedDict as the output
+        # HuggingFace uses its own OrderedDict-based structures as the output
# between layers, so we have to disable this check.
kwargs_offset = len(input_tensor)
args_name_list = list(sig.parameters.keys())
@@ -256,7 +257,7 @@ def call_module(module, args=None, kwargs=None):
def customized_partition(exec_seq):
'''
- This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
+ This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
annotation to note the partition point.
'''
customized_parts = {}
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 7dc15f0a0635..4bd7d5208a64 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -30,28 +30,59 @@
### Quick Start
-The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.):
+The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers's `cutlass_op` provides a supplementary optimization):
```python
-from colossalai.shardformer import ShardConfig, Shard
+from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertForMaskedLM
+import colossalai
# launch colossalai
-colossalai.launch_from_torch()
+colossalai.launch_from_torch(config={})
# create model
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
# create huggingface model as normal
-shard_config = ShardConfig()
+shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+ pipeline_stage_manager=stage_manager,
+ enable_tensor_parallelism=True,
+ enable_fused_normalization=True,
+ enable_flash_attention=True,
+ enable_jit_fused=True,
+ enable_sequence_parallelism=True,
+ enable_sequence_overlap=True)
+
shard_former = ShardFormer(shard_config=shard_config)
-sharded_model = shard_former.optimize(model).to('cuda')
+sharded_model, shared_params = shard_former.optimize(model)
+sharded_model = sharded_model.to('cuda')
# do everything like normal
...
```
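Note that the snippet above references `tp_group` and `stage_manager` without defining them. As a hedged sketch, they could be created with the `ProcessGroupMesh` and `PipelineStageManager` classes touched elsewhere in this PR (the 2x2 mesh layout and axis indices below are illustrative assumptions):

```python
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager

# requires torch.distributed to be initialized first,
# e.g. via colossalai.launch_from_torch above
PP_AXIS, TP_AXIS = 0, 1
pg_mesh = ProcessGroupMesh(2, 2)    # 2 pipeline stages x 2 tensor-parallel ranks
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS)
```

If pipeline parallelism is not needed, `pipeline_stage_manager` can simply be left as None.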
+The following are descriptions of `ShardConfig`'s arguments:
+
+- `tensor_parallel_process_group`: The process group for tensor parallelism; it is necessary when using tensor parallelism. Defaults to None, which means the global process group.
+
+- `pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
+
+- `enable_tensor_parallelism`: Whether to use tensor parallelism. Defaults to True.
+
+- `enable_fused_normalization`: Whether to use fused layernorm. Defaults to False.
+
+- `enable_flash_attention`: Whether to switch on flash attention. Defaults to False.
+
+- `enable_jit_fused`: Whether to switch on JIT fused operators. Defaults to False.
+
+- `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
+
+- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False.
+
+- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
+
+- `inference_only`: Whether only doing forward passing. Defaults to False.
+
### Write your own policy
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
@@ -82,44 +113,30 @@ We will follow this roadmap to develop Shardformer:
- [x] API Implementation
- [x] Unit Testing
- [ ] Policy Implementation
- - [ ] Hugging Face
- - [ ] NLP
- - [x] BERT
- - [x] T5
- - [x] LlaMa
- - [x] GPT2
- - [x] OPT
- - [x] BLOOM
- - [ ] GLM
- - [ ] RoBERTa
- - [ ] ALBERT
- - [ ] ERNIE
- - [ ] GPT Neo
- - [ ] GPT-J
- - [ ] CV
- - [x] ViT
- - [ ] BEiT
- - [ ] SwinTransformer
- - [ ] SwinTransformer V2
- - [ ] Audio
- - [x] Whisper
- - [ ] Multi-modal
- - [x] SAM
- - [x] BLIP-2
-- [ ] Flash Attention Support
- - [ ] NLP
- - [x] BERT
- - [x] T5
- - [x] LlaMa
- - [x] GPT2
- - [x] OPT
- - [x] BLOOM
- - [ ] GLM
- - [ ] RoBERTa
- - [ ] ALBERT
- - [ ] ERNIE
- - [ ] GPT Neo
- - [ ] GPT-J
+
+| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
+| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
+| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
+| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+
## 💡 API Design
@@ -286,41 +303,36 @@ class ShardFormer:
Example:
+ org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+ shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
- shard_former.init_distributed()
- model = shard_former.optimize(model, policy=policy)
- dataloader = shard_former.shard_dataset(dataset)
+ model, shared_params = shard_former.optimize(org_model)
"""
def __init__(self, shard_config: ShardConfig):
"""
Do two things:
- 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
+        1. Create a distributed coordinator
2. serve as a store for shard config
"""
self.shard_config = shard_config
- self.pg_manager = None
+ self.coordinator = DistCoordinator()
- def init_distributed(self) -> colossalai.cluster.ProcessGroupManager:
- """
- Initialize the distributed process group according to the
- """
- pg_manager = ...
- self.pg_manager = pg_manager
- return pg_manager
+ def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
+ r"""
+ This method will optimize the model based on the given policy.
- def shard_model(self, model: torch.nn.Module,policy: Policy) -> torch.nn.Module:
- """
- Shard model for TP and PP
- """
- ...
+ Args:
+            model (`torch.nn.Module`): the original huggingface model
+            shard_config (`ShardConfig`): the config describing how to distribute the model
+ policy (`Policy`): the custom policy for sharding
- def shard_dataset(self, dataset: Dataset) -> Dataloader:
- """
- Shard dataset for DP
+ Returns: the sharded model and the shared parameters
"""
- ...
+ sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
+ shared_params = sharder.shard()
+ return model, shared_params
```
## ⌨️ Development Notes
@@ -429,12 +441,24 @@ As shown in the figures above, when the sequence length is around 1000 or greate
### Convergence
-To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
+To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. The example uses Shardformer together with Pipeline Parallelism and Data Parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results.
-| accuracy | f1 | loss | GPU number | model shard |
+The configurations are as follows:
+```python
+batch_size = 2
+epoch = 3
+lr = 2.4e-5
+accumulation_steps = 8
+warmup_fraction = 0.03
+```
+
+
+
+| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
-| 0.82594 | 0.87441 | 0.09913 | 4 | True |
-| 0.81884 | 0.87299 | 0.10120 | 2 | True |
-| 0.81855 | 0.87124 | 0.10357 | 1 | False |
+| 0.82971 | 0.87713 | 0.23194 | 4 | True |
+| 0.83797 | 0.88006 | 0.22683 | 2 | True |
+| 0.84521 | 0.88700 | 0.21822 | 1 | False |
+
Overall, the results demonstrate that using Shardformer during model training does not affect convergence.
diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py
index de82305b2547..81be2017855c 100644
--- a/colossalai/shardformer/examples/convergence_benchmark.py
+++ b/colossalai/shardformer/examples/convergence_benchmark.py
@@ -49,9 +49,12 @@ def train(args):
# if multiple GPUs, shard the model
if dist.get_world_size() > 1:
- shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
+ tp_group = dist.new_group(backend='nccl')
+ shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+ enable_tensor_parallelism=True,
+ enable_all_optimization=True)
shard_former = ShardFormer(shard_config=shard_config)
- model = shard_former.optimize(model)
+ model, _ = shard_former.optimize(model)
optim = Adam(model.parameters(), lr=args.lr)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh
index 1c281abcda6d..22f13a7cf827 100644
--- a/colossalai/shardformer/examples/convergence_benchmark.sh
+++ b/colossalai/shardformer/examples/convergence_benchmark.sh
@@ -1,7 +1,7 @@
torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
- --max_epochs 1 \
+ --max_epochs 3 \
--batch_size 2 \
--lr 2.4e-5 \
--fused_layernorm False \
diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py
index 9c7b76bcf0a6..2f186709d946 100644
--- a/colossalai/shardformer/examples/performance_benchmark.py
+++ b/colossalai/shardformer/examples/performance_benchmark.py
@@ -29,7 +29,8 @@ def data_gen_for_sequence_classification(batch_size, seq_length):
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
- num_labels=16)
+ num_labels=16,
+ pad_token_id=2)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
@@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
if provider == "shard_model":
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
- sharded_model = shard_former.optimize(model).cuda()
+ sharded_model, _ = shard_former.optimize(model)
+ sharded_model = sharded_model.cuda()
fn = lambda: train(sharded_model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 7e97bee01b33..45b305733813 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -1,3 +1,5 @@
+from typing import Any
+
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -141,6 +143,240 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None, None
+class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
+ """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+ Args:
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whether to overlap the all_gather op and the gradient computation in the backward pass.
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+ ctx.process_group = process_group
+ ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
+ ctx.dim = dim
+ ctx.overlap = overlap
+
+ input_parallel = _gather(input_, dim, process_group)
+
+ if bias is not None:
+ output = F.linear(input_parallel, weight, bias)
+ else:
+ output = F.linear(input_parallel, weight)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ dim = ctx.dim
+ process_group = ctx.process_group
+ overlap = ctx.overlap
+
+ if not overlap:
+ input_parallel = _gather(input_, dim, process_group)
+
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
+
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype,
+ device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Delay the start of weight gradient computation shortly (3us) to have
+ # reduce-scatter scheduled first and have GPU resources allocated
+ _ = torch.empty(1, device=grad_output.device) + 1
+
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+
+ else:
+ input_ = input_.contiguous()
+ world_size = dist.get_world_size(process_group)
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+
+ # do the all-gather asynchronously
+ gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+ # overlap the all-gather with gradient computation and data preparation
+ # compute the input gradient first; it does not need the gathered input
+ grad_input = grad_output.matmul(weight)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+ # prepare data
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+ # wait until all-gather finished
+ gather_handle.wait()
+
+ # do reduce-scatter in async way
+ reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+ # calculate gradient
+ if len(input_parallel.shape) > 2:
+ input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+ grad_weight = grad_output.t().matmul(input_parallel)
+ # wait until reduce-scatter finished
+ reducescatter_handle.wait()
+
+ return output, grad_weight, grad_bias, None, None, None, None
+
+
+class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
+ """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+ Args:
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, process_group, dim):
+ ctx.dim = dim
+ ctx.process_group = process_group
+
+ # do reduce-scatter
+ new_shape = list(input_.shape)
+ assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
+ f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
+ new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+ input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
+ output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
+ dist.reduce_scatter(output, input_list, group=process_group)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ dim = ctx.dim
+ process_group = ctx.process_group
+
+ return _gather(grad_output, dim, process_group), None, None
+
+
+class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
+ """
+ This class is designed for matmul operation with gather forward and reduce-scatter backward.
+
+ Args:
+ input_ (`torch.Tensor`): input matrix.
+ dim (int): the dimension to perform split and gather
+ process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+ ctx.process_group = process_group
+ ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
+ ctx.dim = dim
+ ctx.overlap = overlap
+
+ input_parallel = _gather(input_, dim, process_group)
+
+ output = torch.matmul(input_parallel, weight)
+
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ dim = ctx.dim
+ process_group = ctx.process_group
+ overlap = ctx.overlap
+
+ if not overlap:
+ input_parallel = _gather(input_, dim, process_group)
+
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight.T)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
+
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype,
+ device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Delay the start of weight gradient computation shortly (3us) to have
+ # reduce-scatter scheduled first and have GPU resources allocated
+ _ = torch.empty(1, device=grad_output.device) + 1
+
+ grad_weight = total_input.t().matmul(grad_output)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+
+ else:
+ world_size = dist.get_world_size(process_group)
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+
+ # do the all-gather asynchronously
+ gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+ # overlap the all-gather with gradient computation and data preparation
+ # compute the input gradient first; it does not need the gathered input
+ grad_input = grad_output.matmul(weight.T)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+ # prepare data
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+ # wait until all-gather finished
+ gather_handle.wait()
+
+ # do reduce-scatter in async way
+ reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+ # calculate gradient
+ if len(input_parallel.shape) > 2:
+ input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+ grad_weight = input_parallel.t().matmul(grad_output)
+ # wait until reduce-scatter finished
+ reducescatter_handle.wait()
+
+ return output, grad_weight, grad_bias, None, None, None, None
+
+
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the chunk corresponding to the rank.
@@ -200,6 +436,26 @@ def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None
+class _GatherForwardSplitBackward(torch.autograd.Function):
+ """Gather the input from model parallel region and concatenate.
+
+ Args:
+ input_ (`torch.Tensor`): input matrix.
+ dim (int): the dimension to gather along.
+ process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
+ """
+
+ @staticmethod
+ def forward(ctx, input_, dim, process_group):
+ ctx.process_group = process_group
+ ctx.dim = dim
+ return _gather(input_, dim, process_group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split(grad_output, ctx.dim, ctx.process_group), None, None
+
+
def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
@@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None):
return input_
# all gather
- rank = dist.get_rank(process_group)
+ input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
- tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat
@@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None):
return output
-class _GatherForwardSplitBackward(torch.autograd.Function):
- """Gather the input from model parallel region and concatenate.
+def _reduce_scatter(input_, dim=1, process_group=None):
+ """ Do reduce-scatter operation.
Args:
- input_: input matrix.
- parallel_mode: parallel mode.
- dim: dimension
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ dim (int): The dimension to perform reduce-scatter.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
+ world_size = dist.get_world_size(process_group)
+ if world_size == 1:
+ return input_
- @staticmethod
- def forward(ctx, input_, dim, process_group):
- ctx.process_group = process_group
- ctx.dim = dim
- return _gather(input_, dim, process_group)
+ # reduce-scatter
+ new_shape = list(input_.shape)
+ assert new_shape[dim] % world_size == 0, \
+ f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({world_size}). '
+ new_shape[dim] = new_shape[dim] // world_size
+ input_list = [item.contiguous() for item in torch.chunk(input_, world_size, dim=dim)]
+ output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
+ dist.reduce_scatter(output, input_list, group=process_group)
- @staticmethod
- def backward(ctx, grad_output):
- return _split(grad_output, ctx.dim, ctx.process_group), None, None
+ return output
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
@@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
+ overlap):
+ return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
+ async_grad_reduce_scatter, dim, overlap)
+
+
+def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
+ return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+
+
+def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
+ overlap):
+ return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
+ async_grad_reduce_scatter, dim, overlap)
+
+
def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
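The autograd functions added above all rely on the duality that the backward of an all-gather is a reduce-scatter (and vice versa), which is what keeps sequence parallelism numerically equivalent to the unsharded computation. A minimal, self-contained sketch of that pairing, separate from the fused linear/matmul and overlap machinery in the actual diff:

```python
import torch
import torch.distributed as dist


class GatherForwardReduceScatterBackward(torch.autograd.Function):
    """All-gather a sequence-sharded tensor in forward; reduce-scatter its gradient in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.dim, ctx.process_group = dim, process_group
        world_size = dist.get_world_size(process_group)
        if world_size == 1:
            return input_
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_.contiguous(), group=process_group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if world_size == 1:
            return grad_output, None, None
        # sum the gradient contributions from all ranks and keep only this rank's shard
        grads = [g.contiguous() for g in torch.chunk(grad_output, world_size, dim=ctx.dim)]
        grad_input = torch.empty_like(grads[0])
        dist.reduce_scatter(grad_input, grads, group=ctx.process_group)
        return grad_input, None, None
```

`_LinearWithReduceScatterForwardGatherBackward` above is the mirror image: reduce-scatter in forward, all-gather in backward.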
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index d59b68ce4480..111d51b3f8d8 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -24,6 +24,8 @@
from ._operation import (
gather_forward_split_backward,
+ linear_gather_forward_reducescatter_backward,
+ linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
split_forward_gather_backward,
@@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule):
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
+ seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
+ overlap (`bool`): If set to ``True``, it will overlap the input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
@@ -69,6 +73,9 @@ def __init__(self,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
+ seq_parallel: bool = False,
+ seq_parallel_dim: int = 1,
+ overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@@ -80,6 +87,9 @@ def __init__(self,
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
+ self.seq_parallel = seq_parallel
+ self.seq_parallel_dim = seq_parallel_dim
+ self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@@ -180,7 +190,12 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+ if self.seq_parallel:
+ output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
+ self.process_group, True,
+ self.seq_parallel_dim, self.overlap)
+ else:
+ output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
@@ -203,6 +218,8 @@ class Linear1D_Row(ParallelModule):
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+ process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+ seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
@@ -221,6 +238,8 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
+ seq_parallel: bool = False,
+ seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@@ -238,6 +257,8 @@ def __init__(self,
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
+ self.seq_parallel = seq_parallel
+ self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
@@ -373,7 +394,11 @@ def forward(self, input_: Tensor) -> Tensor:
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
- output = reduce_forward(output_parallel, self.process_group)
+ if self.seq_parallel:
+ output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
+ self.seq_parallel_dim)
+ else:
+ output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
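With `seq_parallel=True`, `Linear1D_Col` and `Linear1D_Row` are intended to be used as a pair so that activations outside the pair stay sequence-sharded. A usage sketch, assuming an initialized default process group, a CUDA device, and that the constructors accept the arguments introduced in this diff (outputs assumed to be plain tensors under the default `skip_bias_add=False`):

```python
import torch
import torch.distributed as dist
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row

tp_size = dist.get_world_size()

# column-parallel: all-gathers the sequence-sharded input in forward,
# reduce-scatters the input gradient in backward
up_proj = Linear1D_Col(1024, 4096, process_group=None,    # None -> default group
                       seq_parallel=True, seq_parallel_dim=1, overlap=True)

# row-parallel: reduce-scatters its partial output back to sequence shards
# in forward, all-gathers the gradient in backward
down_proj = Linear1D_Row(4096, 1024, process_group=None,
                         seq_parallel=True, seq_parallel_dim=1)

x = torch.randn(2, 512 // tp_size, 1024, device='cuda')   # sequence-sharded activation
h = up_proj(x)       # [2, 512, 4096 // tp_size]: full sequence, sharded features
y = down_proj(h)     # [2, 512 // tp_size, 1024]: sequence-sharded again
```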
diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py
index bda147b121ab..4f391920e29b 100644
--- a/colossalai/shardformer/layer/parallel_module.py
+++ b/colossalai/shardformer/layer/parallel_module.py
@@ -10,6 +10,7 @@
from torch.distributed import ProcessGroup
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
+from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
@@ -56,13 +57,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
for name, param in self._parameters.items():
if param is not None:
- param_ = param if keep_vars else param.detach()
- if is_distributed_tensor(param_):
- destination[prefix + name] = to_global(param_)
- elif is_customized_distributed_tensor(param_):
- destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
- else:
- destination[prefix + name] = param_
+ destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
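The state-dict branch above is folded into a shared helper from `colossalai.checkpoint_io.utils`. Assuming `gather_distributed_param` preserves the semantics of the code it replaces, it is equivalent to this reconstruction (hypothetical; written from the deleted lines, not from the helper's source):

```python
from colossalai.tensor.d_tensor import (
    is_customized_distributed_tensor,
    is_distributed_tensor,
    to_global,
    to_global_for_customized_distributed_tensor,
)


def gather_distributed_param(param, keep_vars: bool = False):
    # hypothetical reconstruction of the helper from the logic it replaces
    param_ = param if keep_vars else param.detach()
    if is_distributed_tensor(param_):
        return to_global(param_)    # sharded d-tensor -> full global tensor
    elif is_customized_distributed_tensor(param_):
        return to_global_for_customized_distributed_tensor(param_)
    return param_                   # plain tensor, returned as-is
```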
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index df942d43ee2d..5ce77805f9b8 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -25,7 +25,9 @@
from ._operation import (
gather_forward_split_backward,
+ linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
+ matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
@@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number of items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+ seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
@@ -173,6 +176,8 @@ def __init__(self,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
+ seq_parallel: bool = False,
+ overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
@@ -185,6 +190,8 @@ def __init__(self,
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
+ self.seq_parallel = seq_parallel
+ self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
@@ -296,15 +303,19 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
- # Set up backprop all-reduce.
- input_parallel = reduce_backward(input_, self.process_group)
- # input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
- self.async_communication)
+ if self.seq_parallel:
+ input_parallel = input_
+ output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
+ self.process_group, True, 1, self.overlap)
+ else:
+ # Set up backprop all-reduce.
+ input_parallel = reduce_backward(input_, self.process_group)
+ output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
+ self.async_communication)
if self.gather_output:
# All-gather across the partitions.
@@ -329,6 +340,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+ seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
@@ -346,6 +358,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
+ seq_parallel: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@@ -363,6 +376,7 @@ def __init__(self,
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
+ self.seq_parallel = seq_parallel
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
@@ -499,7 +513,10 @@ def forward(self, input_: Tensor) -> Tensor:
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = torch.matmul(input_, self.weight)
- output = reduce_forward(output_parallel, self.process_group)
+ if self.seq_parallel:
+ output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+ else:
+ output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
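The matmul-based variants exist because GPT-2's fused Conv1D layers store weights as `[in_features, out_features]` (`y = x @ W + b`), whereas `F.linear` uses `[out_features, in_features]`; that is why this backward multiplies by `weight.T` where the linear variant multiplies by `weight`. A quick check of the gradient rule:

```python
import torch

x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(8, 16)              # Conv1D layout: [in_features, out_features]

y = x.matmul(w)                     # forward: y = x @ W
y.backward(torch.ones_like(y))

# dL/dx = dL/dy @ W.T, matching grad_output.matmul(weight.T) in the backward above
assert torch.allclose(x.grad, torch.ones_like(y).matmul(w.T))
```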
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 5bd1c531cc68..30855a622adb 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1,6 +1,6 @@
import math
import warnings
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -29,6 +29,8 @@
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
class BertPipelineForwards:
@@ -56,6 +58,7 @@ def bert_model_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
# TODO(jianghai): add explanation of the output here.
r"""
@@ -177,6 +180,17 @@ def bert_model_forward(
start_idx, end_idx = stage_index[0], stage_index[1]
# layer_outputs
layer_outputs = hidden_states if hidden_states is not None else None
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ if shard_config is not None and shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = split_forward_gather_backward(
+ encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
+
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
if stage_manager.is_first_stage() and idx == 0:
encoder_attention_mask = encoder_extended_attention_mask
@@ -223,11 +237,17 @@ def custom_forward(*inputs):
all_cross_attentions = all_cross_attentions + \
(layer_outputs[2],)
+ # When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
+ if shard_config is not None and shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# end of a stage loop
- sequence_output = layer_outputs[0] if layer_outputs is not None else None
+ sequence_output = hidden_states if hidden_states is not None else None
if stage_manager.is_last_stage():
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -268,6 +288,7 @@ def bert_for_pretraining_forward(
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
@@ -294,6 +315,7 @@ def bert_for_pretraining_forward(
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
stage_index=stage_index,
+ shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -350,6 +372,7 @@ def bert_lm_head_model_forward(
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -404,7 +427,8 @@ def bert_lm_head_model_forward(
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@@ -457,6 +481,7 @@ def bert_for_masked_lm_forward(
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -491,6 +516,7 @@ def bert_for_masked_lm_forward(
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
+ shard_config=shard_config,
)
if stage_manager.is_last_stage():
@@ -532,6 +558,7 @@ def bert_for_next_sentence_prediction_forward(
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
**kwargs,
):
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
@@ -594,7 +621,8 @@ def bert_for_next_sentence_prediction_forward(
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@@ -636,6 +664,7 @@ def bert_for_sequence_classification_forward(
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -666,7 +695,8 @@ def bert_for_sequence_classification_forward(
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@@ -726,6 +756,7 @@ def bert_for_token_classification_forward(
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -742,21 +773,20 @@ def bert_for_token_classification_forward(
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
- outputs = BertPipelineForwards.bert_model_forward(
- self.bert,
- input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- hidden_states=hidden_states,
- stage_manager=stage_manager,
- stage_index=stage_index,
- )
+ outputs = BertPipelineForwards.bert_model_forward(self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
@@ -799,6 +829,7 @@ def bert_for_multiple_choice_forward(
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -843,6 +874,7 @@ def bert_for_multiple_choice_forward(
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
+ shard_config=shard_config,
)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@@ -886,6 +918,7 @@ def bert_for_question_answering_forward(
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
# NOTE: the arg start_position and end_position are used only for the last stage
r"""
@@ -909,21 +942,20 @@ def bert_for_question_answering_forward(
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
- outputs = BertPipelineForwards.bert_model_forward(
- self.bert,
- input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- hidden_states=hidden_states,
- stage_manager=stage_manager,
- stage_index=stage_index,
- )
+ outputs = BertPipelineForwards.bert_model_forward(self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
@@ -1101,3 +1133,153 @@ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.T
return hidden_states
return forward
+
+
+def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ embedding_output = split_forward_gather_backward(embedding_output,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = split_forward_gather_backward(
+ encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ # When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
+ sequence_output = gather_forward_split_backward(sequence_output,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+ return forward
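Each patched model forward follows the same envelope: split the embedding output along the sequence dimension, run the transformer blocks on the shard (the sequence-parallel linears internally gather the full sequence where attention needs it and re-scatter afterwards), then gather before the pooler or head. A condensed sketch of that envelope, with an illustrative helper name:

```python
from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def run_blocks_sequence_parallel(hidden_states, blocks, tp_group):
    # [batch, seq_len, hidden] -> [batch, seq_len / tp_size, hidden]
    hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=tp_group)
    for block in blocks:
        hidden_states = block(hidden_states)
    # restore the full sequence for the downstream pooler / LM head
    return gather_forward_split_backward(hidden_states, dim=1, process_group=tp_group)
```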
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 12276635ecfa..66f24dc6088b 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -23,6 +23,10 @@
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.shard import ShardConfig
+
+logger = logging.get_logger(__name__)
def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@@ -111,6 +115,7 @@ def bloom_model_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']:
@@ -205,6 +210,13 @@ def bloom_model_forward(
past_key_values_length=past_key_values_length,
)
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]),
start=start_idx):
@@ -248,6 +260,12 @@ def custom_forward(*inputs):
all_self_attentions = all_self_attentions + \
(outputs[2 if use_cache else 1],)
+ # When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
if stage_manager.is_last_stage():
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
@@ -287,6 +305,7 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
**deprecated_arguments):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -327,7 +346,8 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@@ -380,6 +400,7 @@ def bloom_for_sequence_classification_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
**deprecated_arguments,
):
r"""
@@ -424,6 +445,7 @@ def bloom_for_sequence_classification_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -503,6 +525,7 @@ def bloom_for_token_classification_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
**deprecated_arguments,
):
r"""
@@ -547,6 +570,7 @@ def bloom_for_token_classification_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -597,6 +621,7 @@ def bloom_for_question_answering_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -632,6 +657,7 @@ def bloom_for_question_answering_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@@ -700,8 +726,7 @@ def forward(
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
- batch_size, tgt_len, _ = hidden_states.size()
- assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
+ batch_size, tgt_len, _ = query_layer.size()
_, kv_length, _, _ = key_layer.size()
@@ -896,3 +921,156 @@ def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
return self.bloom_gelu_forward(x, bias)
return forward
+
+
+def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ from transformers import BloomModel
+
+ def forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ if past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ causal_mask,
+ layer_past,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ return forward
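Both the split helper and `_reduce_scatter` chunk the sequence dimension into `tp_size` equal parts, so the (padded) sequence length must be divisible by the tensor-parallel size. A hypothetical pre-flight guard one might add before enabling the feature:

```python
def check_seq_parallel_compatible(seq_len: int, tp_size: int) -> None:
    # hypothetical helper: uneven lengths would trip the divisibility
    # assert inside the reduce-scatter/split paths above
    if seq_len % tp_size != 0:
        raise ValueError(
            f"sequence length {seq_len} is not divisible by tensor parallel size {tp_size}; "
            "pad inputs to a multiple of the TP size before enabling sequence parallelism.")
```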
diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm2.py
similarity index 69%
rename from colossalai/shardformer/modeling/chatglm.py
rename to colossalai/shardformer/modeling/chatglm2.py
index 409e2e1f5497..16dcf87c8cfc 100644
--- a/colossalai/shardformer/modeling/chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -9,6 +9,8 @@
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@@ -146,6 +148,7 @@ def chatglm_model_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
output_hidden_states = (output_hidden_states
@@ -198,6 +201,11 @@ def chatglm_model_forward(
all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]
+
+ # split the input tensor along the sequence dimension, which is dim 0 for ChatGLM
+ # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
@@ -214,6 +222,11 @@ def chatglm_model_forward(
hidden_states, kv_cache = layer_ret
if use_cache:
presents = presents + (kv_cache,)
+
+ # When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
@@ -233,23 +246,22 @@ def chatglm_model_forward(
return {'hidden_states': hidden_states}
@staticmethod
- def chatglm_for_conditional_generation_forward(
- self: ChatGLMForConditionalGeneration,
- input_ids: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- return_last_logit: Optional[bool] = False,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ):
+ def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_last_logit: Optional[bool] = False,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None):
logger = logging.get_logger(__name__)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
@@ -266,6 +278,7 @@ def chatglm_for_conditional_generation_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
+ shard_config=shard_config,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
@@ -296,3 +309,91 @@ def chatglm_for_conditional_generation_forward(
)
else:
return transformer_outputs
+
+
+def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ def forward(
+ self,
+ input_ids,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+ batch_size, seq_length = input_ids.shape
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(
+ batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype,
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ [
+ attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask,
+ ],
+ dim=-1,
+ )
+
+ if full_attention_mask is None:
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+ # Rotary positional embeddings
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+ if position_ids is not None:
+ rotary_pos_emb = rotary_pos_emb[position_ids]
+ else:
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+ # Run encoder.
+ # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
+ inputs_embeds = split_forward_gather_backward(inputs_embeds,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+ inputs_embeds,
+ full_attention_mask,
+ rotary_pos_emb=rotary_pos_emb,
+ kv_caches=past_key_values,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ )
+
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ if not return_dict:
+ return tuple(v for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ return forward
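Unlike BERT, BLOOM, and GPT-2, ChatGLM keeps activations sequence-first (`[seq_len, batch, hidden]`), which is why the split and gather above use `dim=0` rather than `dim=1`. A small shape illustration:

```python
import torch

tp_size = 4
batch_first = torch.randn(8, 1024, 256)   # BERT/BLOOM/GPT-2: [batch, seq, hidden]
seq_first = torch.randn(1024, 8, 256)     # ChatGLM: [seq, batch, hidden]

print(torch.chunk(batch_first, tp_size, dim=1)[0].shape)   # torch.Size([8, 256, 256])
print(torch.chunk(seq_first, tp_size, dim=0)[0].shape)     # torch.Size([256, 8, 256])
```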
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 47835d5d5468..84deafefeadd 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -21,6 +21,8 @@
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.shard import ShardConfig
class GPT2PipelineForwards:
@@ -47,7 +49,8 @@ def gpt2_model_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
@@ -75,9 +78,9 @@ def gpt2_model_forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
input_shape = input_ids.size()
- input_ids = input_ids.view(-1, seq_length)
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
@@ -86,13 +89,14 @@ def gpt2_model_forward(
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
- token_type_ids = token_type_ids.view(-1, seq_length)
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
- batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
+ hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
+ batch_size = hidden_states.shape[0]
# GPT2Attention mask.
if attention_mask is not None:
@@ -133,9 +137,9 @@ def gpt2_model_forward(
if stage_manager.is_first_stage():
if position_ids is not None:
- position_ids = position_ids.view(-1, seq_length)
+ position_ids = position_ids.view(-1, input_shape[-1])
else:
- position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+ position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
if inputs_embeds is None:
@@ -145,7 +149,7 @@ def gpt2_model_forward(
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
- hidden_states = self.drop(hidden_states)
+ hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
@@ -159,6 +163,13 @@ def gpt2_model_forward(
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx):
@@ -212,6 +223,12 @@ def custom_forward(*inputs):
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+        # When sequence parallelism is enabled, gather the output tensor in the forward pass and split it in the backward pass
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
@@ -257,7 +274,8 @@ def gpt2_lmhead_model_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -285,7 +303,8 @@ def gpt2_lmhead_model_forward(
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -335,7 +354,8 @@ def gpt2_double_heads_model_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
r"""
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
@@ -367,7 +387,8 @@ def gpt2_double_heads_model_forward(
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -421,7 +442,8 @@ def gpt2_for_question_answering_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -449,7 +471,8 @@ def gpt2_for_question_answering_forward(
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -508,7 +531,8 @@ def gpt2_for_token_classification_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -534,7 +558,8 @@ def gpt2_for_token_classification_forward(
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -578,7 +603,8 @@ def gpt2_for_sequence_classification_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -613,7 +639,8 @@ def gpt2_for_sequence_classification_forward(
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
- stage_index=stage_index)
+ stage_index=stage_index,
+ shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@@ -695,8 +722,6 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
- _, tgt_len, _ = hidden_states.size()
- assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
@@ -753,3 +778,211 @@ def forward(
return outputs
return forward
+
+
+def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+ if position_ids is None:
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ # GPT2Attention mask.
+ if attention_mask is not None:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger = logging.get_logger(__name__)
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
+ if layer_past is not None:
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        # When sequence parallelism is enabled, gather the output tensor in the forward pass and split it in the backward pass
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ hidden_states = self.ln_f(hidden_states)
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ return forward
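
`gpt2_sequence_parallel_forward_fn` follows the factory pattern used throughout shardformer: it closes over a `ShardConfig` and returns a replacement `forward` with the usual `GPT2Model.forward` signature. A hypothetical sketch of how such a forward gets bound onto a model instance (the helper name is illustrative; the real wiring goes through the policy's `append_or_create_method_replacement`):

```python
from types import MethodType

from colossalai.shardformer.modeling.gpt2 import gpt2_sequence_parallel_forward_fn
from colossalai.shardformer.shard import ShardConfig


def bind_sequence_parallel_forward(model, shard_config: ShardConfig):
    # the factory closes over shard_config and returns an unbound function
    new_forward = gpt2_sequence_parallel_forward_fn(shard_config)
    # bind it as an instance method so `self` resolves to this model
    model.forward = MethodType(new_forward, model)
    return model
```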
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f1d2998bbee4..ff622c306c59 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Callable, List, Optional, Tuple
import torch
@@ -19,6 +20,7 @@ class LlamaPipelineForwards:
under pipeline setting.
'''
+ @staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
@@ -169,6 +171,7 @@ def custom_forward(*inputs):
        # always return dict for intermediate stage
return {'hidden_states': hidden_states}
+ @staticmethod
def llama_for_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
@@ -276,6 +279,7 @@ def llama_for_causal_lm_forward(
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
+ @staticmethod
def llama_for_sequence_classification_forward(
self: LlamaForSequenceClassification,
input_ids: torch.LongTensor = None,
@@ -389,9 +393,18 @@ def llama_for_sequence_classification_forward(
def get_llama_flash_attention_forward():
+
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+ llama_version = 2
+ try:
+ from transformers.models.llama.modeling_llama import repeat_kv
+    except ImportError:
+        warnings.warn("LLaMA v1 detected: v1 does not provide the repeat_kv function")
+ llama_version = 1
+
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def forward(
@@ -415,6 +428,7 @@ def forward(
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -424,6 +438,11 @@ def forward(
past_key_value = (key_states, value_states) if use_cache else None
+ # repeat k/v heads if n_kv_heads < n_heads
+ if llama_version == 2:
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
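
The version probe above works because `repeat_kv` only exists in LLaMA 2-era transformers, where it expands the grouped key/value heads so their count matches the query heads before attention. A sketch of its effect, mirroring the transformers helper (shapes assumed):

```python
import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq_len, head_dim) -> (bsz, num_kv_heads * n_rep, seq_len, head_dim)
    bsz, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(bsz, num_kv_heads * n_rep, slen, head_dim)
```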
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index b4251f33b457..ad088f3702e5 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -518,7 +518,6 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
- assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index 9fc0b7488803..2ce52163ac32 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -1,9 +1,9 @@
-import logging
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
+from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -72,18 +72,17 @@ def pp_forward(
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if output_attentions is not None:
- logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
- output_attentions = None
- if output_hidden_states is not None:
- logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
- output_hidden_states = None
+ logger = logging.get_logger(__name__)
+
+    # Preprocess passed-in arguments
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
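
The move from the stdlib `logging` module to `transformers.utils.logging` with `warning_once` matters here because pipeline forwards run on every step; `warning_once` deduplicates the message per process. A minimal usage sketch:

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)
# emitted at most once per process, no matter how many forward passes run
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
```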
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index 0a16c6f788da..62f8f7b4763e 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -1,7 +1,26 @@
-from typing import Optional, Tuple
+import random
+from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ SequenceClassifierOutput,
+)
+from transformers.models.whisper.modeling_whisper import (
+ WhisperEncoder,
+ WhisperForAudioClassification,
+ WhisperForConditionalGeneration,
+ WhisperModel,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward():
@@ -247,3 +266,697 @@ def forward(
return outputs
return forward
+
+
+class WhisperPipelineForwards:
+ '''
+ This class serves as a micro library for forward function substitution of Llama models
+ under pipeline setting.
+ '''
+
+ @staticmethod
+ def whisper_encoder_forward(
+ self: WhisperEncoder,
+ input_features,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_states=None,
+ all_attentions=None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+ Args:
+ input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
+ Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+ `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
+ and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+            attention_mask (`torch.Tensor`, *optional*):
+ Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
+                but it is not used. By default, the silence in the input log mel spectrogram is ignored.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ logger = logging.get_logger(__name__)
+
+ stage = stage_manager.stage
+ at_first_stage = (stage == 0)
+ at_last_stage = (stage == decoder_starting_stage - 1)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Process inputs if at the first stage of encoder.
+ if at_first_stage:
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
+ embed_pos = self.embed_positions.weight
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ assert head_mask.size()[0] == (
+ len(self.layers)
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+
+ else:
+ if hidden_states is None:
+ raise ValueError(
+ "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ for idx in range(start_idx, end_idx):
+ encoder_layer = self.layers[idx]
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ None,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ None,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if at_last_stage:
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions)
+
+ else:
+ return {'hidden_states': hidden_states, 'head_mask': head_mask}
+
+ @staticmethod
+ def whisper_decoder_forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
+ on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ logger = logging.get_logger(__name__)
+ stage = stage_manager.stage
+ at_first_stage = (stage == decoder_starting_stage)
+ at_last_stage = (stage == stage_manager.num_stages - 1)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                    f" {attn_mask.size()[0]}.")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if at_first_stage:
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # embed positions
+ if input_ids is not None:
+ positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+ else:
+ positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
+ past_key_values_length)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ )
+ use_cache = False
+
+ else:
+
+ if hidden_states is None:
+ raise ValueError(
+ "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
+ input_shape = hidden_states.size()[:-1]
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states,
+ past_key_values_length)
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ for idx in range(start_idx, end_idx):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ decoder_layer = self.layers[idx]
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, use_cache)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ None, # encoder attention mask
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None, # past_key_value
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx]
+ if cross_attn_head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ if at_last_stage:
+ hidden_states = self.layer_norm(hidden_states)
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None)
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+ else:
+ return {
+ 'head_mask': head_mask,
+ 'cross_attn_head_mask': cross_attn_head_mask,
+ 'hidden_states': hidden_states,
+ }
+
+ @staticmethod
+ def whisper_model_forward(
+ self: WhisperModel,
+ input_features: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, WhisperModel
+ >>> from datasets import load_dataset
+
+ >>> model = WhisperModel.from_pretrained("openai/whisper-base")
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_features = inputs.input_features
+ >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+ >>> list(last_hidden_state.shape)
+ [1, 2, 512]
+ ```"""
+        logger = logging.get_logger(__name__)
+
+        # TODO: the recorded kv-value tensors are left as () or None for now; this feature may be added in the future.
+        if past_key_values:
+            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            past_key_values = None
+        if output_attentions:
+            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            output_hidden_states = False
+        if use_cache:
+            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+            use_cache = False
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ in_decoder = stage_manager.stage >= decoder_starting_stage
+ if not in_decoder:
+ if encoder_outputs is None:
+ input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
+
+ encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
+ self.encoder,
+ input_features,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+
+ if stage_manager.stage == decoder_starting_stage - 1:
+ # last stage of encoder
+ return {'encoder_hidden_states': encoder_outputs[0]}
+ else:
+ return encoder_outputs
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ at_last_decoder_stage = stage_manager.is_last_stage()
+ at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
+ if encoder_outputs is not None:
+ encoder_hidden_states = encoder_outputs[0]
+ elif encoder_hidden_states is None:
+ raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
+
+ if not at_first_decoder_stage and hidden_states is None:
+            raise ValueError("If not at the first stage of the decoder, non-empty hidden_states must be provided.")
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder,
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+
+ # Directly return outputs of overloaded Whisper forward if not at last stage.
+ if not at_last_decoder_stage:
+ # encoder_hidden_states should be passed to the next stage
+ decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
+ return decoder_outputs
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_hidden_states,
+ )
+
+ @staticmethod
+ def whisper_for_conditional_generation_forward(
+ self: WhisperForConditionalGeneration,
+ input_features: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
+ or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
+ only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_features = inputs.input_features
+
+ >>> generated_ids = model.generate(inputs=input_features)
+
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ >>> transcription
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id,
+ self.config.decoder_start_token_id)
+ in_decoder = stage_manager.stage >= decoder_starting_stage
+ at_last_decoder_stage = stage_manager.is_last_stage()
+ outputs = WhisperPipelineForwards.whisper_model_forward(self.model,
+ input_features,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+ if not in_decoder:
+ return outputs
+
+ if not at_last_decoder_stage:
+ # encoder_hidden_states should be passed to the next stage
+ outputs['encoder_hidden_states'] = encoder_hidden_states
+ return outputs
+
+ lm_logits = self.proj_out(outputs[0])
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # move labels to correct device to enable PP
+ labels = labels.to(lm_logits.device)
+ loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ @staticmethod
+ def whisper_for_audio_classification_forward(
+ self: WhisperForAudioClassification,
+ input_features: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_states=None,
+ all_attentions=None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+ This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
+ Please refer to original code of transformers for more details.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # audio_classification only holds encoder
+ encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
+ self.encoder,
+ input_features,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage,
+ )
+
+ if not stage_manager.is_last_stage():
+ return encoder_outputs
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = torch.stack(encoder_outputs, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = encoder_outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ pooled_output = hidden_states.mean(dim=1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # move labels to correct device to enable PP
+ labels = labels.to(logits.device)
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + encoder_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
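
All of the Whisper pipeline forwards above share one convention: `decoder_starting_stage` splits the pipeline between the encoder and decoder stacks, with stages `[0, decoder_starting_stage)` holding encoder layers and the rest holding decoder layers. An illustrative helper (assumed layout, not part of the codebase) makes the stage roles explicit:

```python
def stage_role(stage: int, num_stages: int, decoder_starting_stage: int) -> str:
    # encoder stages come first; the stage just before decoder_starting_stage
    # runs the final encoder layer norm, the last stage runs the decoder head
    if stage < decoder_starting_stage:
        return "encoder (last)" if stage == decoder_starting_stage - 1 else "encoder"
    return "decoder (last)" if stage == num_stages - 1 else "decoder"


# e.g. 4 stages with decoder_starting_stage=2:
# [stage_role(s, 4, 2) for s in range(4)]
# -> ['encoder', 'encoder (last)', 'decoder', 'decoder (last)']
```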
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index eec339c02872..49613ffb37e0 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -1,5 +1,6 @@
import importlib
from dataclasses import dataclass
+from typing import Optional
import torch.nn as nn
@@ -125,17 +126,33 @@ class PolicyLocation:
# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
- PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
+ PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
- PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
+ PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"),
+}
+
+_INFER_POLICY_LIST = {
+    # LLaMA
+ "transformers.models.llama.modeling_llama.LlamaModel":
+ PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
+ "transformers.models.llama.modeling_llama.LlamaForCausalLM":
+ PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
+ # Bloom
+ "transformers.models.bloom.modeling_bloom.BloomModel":
+ PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
+ "transformers.models.bloom.modeling_bloom.BloomForCausalLM":
+ PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
}
-def import_policy(policy_location: PolicyLocation) -> Policy:
+def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""
- module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+ if inference_only:
+ module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
+ else:
+ module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
return getattr(module, policy_location.class_name)
@@ -151,7 +168,7 @@ def _fullname(obj):
return module + '.' + klass.__qualname__
-def get_autopolicy(model: nn.Module) -> Policy:
+def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
r"""
Return the auto policy for the model
@@ -162,12 +179,15 @@ def get_autopolicy(model: nn.Module) -> Policy:
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
- policy_location = _POLICY_LIST.get(full_name, None)
+ if inference_only:
+ policy_location = _INFER_POLICY_LIST.get(full_name, None)
+ else:
+ policy_location = _POLICY_LIST.get(full_name, None)
if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
- policy = import_policy(policy_location)
+ policy = import_policy(policy_location, inference_only)
return policy()
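
With the new `inference_only` switch, the same model class can resolve to either a training policy or an inference policy. A usage sketch (the checkpoint name is illustrative):

```python
from transformers import LlamaForCausalLM

from colossalai.shardformer.policies.auto_policy import get_autopolicy

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # any LlamaForCausalLM works

train_policy = get_autopolicy(model)                       # regular shardformer policy
infer_policy = get_autopolicy(model, inference_only=True)  # LlamaModelInferPolicy
```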
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index 69493bfb6007..961c6a5259fe 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -11,17 +11,12 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
+from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
-class ParallelModule():
-
- def __init__(self):
- pass
-
-
@dataclass
class SubModuleReplacementDescription:
r"""
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index ace9ada3904f..a141b7bd8fdf 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -10,6 +10,7 @@
from .._utils import getattr_, setattr_
from ..modeling.bert import (
BertPipelineForwards,
+ bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
@@ -47,13 +48,15 @@ def module_policy(self):
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertLayer,
+ BertModel,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
policy = {}
-
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
@@ -69,14 +72,26 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
@@ -85,6 +100,7 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
+ kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@@ -93,10 +109,15 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
+ kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@@ -115,6 +136,12 @@ def module_policy(self):
)
])
+ if use_sequence_parallel:
+ self.append_or_create_method_replacement(
+ description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
+ policy=policy,
+ target_key=BertModel)
+
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
@@ -141,20 +168,26 @@ def module_policy(self):
# use flash attention
if self.shard_config.enable_flash_attention:
- policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_bert_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=BertSelfAttention)
# use jit operator
if self.shard_config.enable_jit_fused:
- policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
- policy[BertOutput] = ModulePolicyDescription(method_replacement={
+ },
+ policy=policy,
+ target_key=BertSelfOutput)
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
+ },
+ policy=policy,
+ target_key=BertOutput)
return policy
@@ -205,7 +238,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
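
The recurring change in this and the following policies, from direct `policy[cls] = ModulePolicyDescription(...)` assignment to `append_or_create_method_replacement`, prevents one feature (say, flash attention) from clobbering method replacements another feature already registered on the same target class. A simplified sketch of the merge semantics (assumed, not the exact implementation):

```python
# assumed import path; ModulePolicyDescription is exported from base_policy
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription


def append_or_create_method_replacement_sketch(description, policy, target_key):
    if target_key in policy:
        # merge into the existing description instead of overwriting it
        existing = policy[target_key].method_replacement or {}
        existing.update(description)
        policy[target_key].method_replacement = existing
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=description)
    return policy
```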
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 50356302e93e..2e5388ab0490 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -285,34 +285,30 @@ def module_policy(self):
# use flash attention
if self.shard_config.enable_flash_attention:
- policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_blip2_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=Blip2Attention)
# use jit operator
if self.shard_config.enable_jit_fused:
- policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
- method_replacement={
- 'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
- 'dropout_add': get_jit_fused_dropout_add_func(),
- })
- policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=Blip2QFormerSelfOutput)
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
+ },
+ policy=policy,
+ target_key=Blip2QFormerOutput)
return policy
def postprocess(self):
- binding_map = {
- 'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
- }
-
- for k, v in binding_map.items():
- src_mod = getattr_(self.model, k)
- dst_mod = getattr_(self.model, v)
- dst_mod.weight = src_mod.weight
-
return self.model
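
The deleted `postprocess` no longer re-ties `language_model.lm_head` to the token embedding by hand. For reference, the removed binding loop amounted to plain PyTorch weight sharing; a toy illustration with hypothetical module names:

```python
import torch.nn as nn

# miniature of the removed binding: share one Parameter object between an
# embedding (src_mod) and an output projection (dst_mod)
embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed.weight          # dst_mod.weight = src_mod.weight
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
```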
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index b35764db3870..7c418d02bcb6 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -12,6 +12,7 @@
BloomPipelineForwards,
build_bloom_alibi_tensor_fn,
get_bloom_flash_attention_forward,
+ get_bloom_sequence_parallel_forward_fn,
get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward,
@@ -43,6 +44,8 @@ def module_policy(self):
policy = {}
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -53,11 +56,14 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
- ),
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'overlap': overlap
+ }),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
- ),
+ kwargs={'seq_parallel': use_sequence_parallel}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
@@ -65,11 +71,14 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
- ),
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'overlap': overlap
+ }),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
- ),
+ kwargs={'seq_parallel': use_sequence_parallel}),
])
policy[BloomModel] = ModulePolicyDescription(
@@ -116,26 +125,40 @@ def module_policy(self):
policy=policy,
target_key=BloomBlock)
+ if use_sequence_parallel:
+ self.append_or_create_method_replacement(
+ description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)},
+ policy=policy,
+ target_key=BloomModel)
+
if self.shard_config.enable_flash_attention:
- policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_bloom_flash_attention_forward(),
- 'dropout_add': get_dropout_add_func()
- })
+ 'dropout_add': get_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=BloomAttention)
# enable jit fused operator
if self.shard_config.enable_jit_fused:
- policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
- policy[BloomMLP] = ModulePolicyDescription(method_replacement={
+ },
+ policy=policy,
+ target_key=BloomAttention)
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_mlp_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
- policy[BloomGelu] = ModulePolicyDescription(method_replacement={
+ },
+ policy=policy,
+ target_key=BloomMLP)
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_gelu_forward(),
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
- })
+ },
+ policy=policy,
+ target_key=BloomGelu)
return policy
@@ -154,7 +177,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
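
For context on the `seq_parallel`/`overlap` kwargs threaded through these linear layers: sequence parallelism shards the activations of the non-tensor-parallel regions (dropout, layernorm) along the sequence dimension across the tensor-parallel group. A toy, torch-only illustration of the sharding — not ColossalAI code, with the communication simulated by `torch.cat`:

```python
import torch

tp_size = 4
hidden = torch.randn(2, 64, 128)              # (batch, seq_len, hidden)
shards = torch.chunk(hidden, tp_size, dim=1)  # one shard per tensor-parallel rank
assert shards[0].shape == (2, 16, 128)        # each rank holds seq_len // tp_size tokens

# entering the tensor-parallel region, shards are gathered back along the
# sequence dimension (an all-gather over the tensor-parallel group); the
# `overlap` flag presumably hides this communication behind the adjacent matmul
restored = torch.cat(shards, dim=1)
assert torch.equal(restored, hidden)
```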
diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm2.py
similarity index 67%
rename from colossalai/shardformer/policies/chatglm.py
rename to colossalai/shardformer/policies/chatglm2.py
index e6b458936637..44898847056a 100644
--- a/colossalai/shardformer/policies/chatglm.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -7,7 +7,7 @@
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards
+from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@@ -15,7 +15,11 @@
GLMBlock,
)
-from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
+from ..modeling.chatglm2 import (
+ get_chatglm_sequence_parallel_forward_fn,
+ get_flash_core_attention_forward,
+ get_jit_fused_glm_block_forward,
+)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -37,6 +41,11 @@ def preprocess(self):
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
+ if self.pipeline_stage_manager is not None:
+ # the batch_size_dim is bound to the model
+ bsz_dim = 1
+ setattr(self.model, 'batch_size_dim', bsz_dim)
+
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@@ -45,8 +54,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
-
policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
sub_module_replacement=[
SubModuleReplacementDescription(
@@ -55,36 +65,43 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)
])
- policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={
- "self_attention.num_attention_heads_per_partition":
- self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- "self_attention.projection_size":
- (self.model.config.kv_channels * self.model.config.num_attention_heads) //
- self.shard_config.tensor_parallel_size,
- "self_attention.qkv_hidden_size":
- (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
- self.shard_config.tensor_parallel_size,
- "self_attention.core_attention.num_attention_heads_per_partition":
- self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- "self_attention.core_attention.hidden_size_per_partition":
- self.model.config.kv_channels * self.model.config.num_attention_heads //
- self.shard_config.tensor_parallel_size,
- },
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="self_attention.query_key_value",
- target_module=col_nn.Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="self_attention.dense",
- target_module=col_nn.Linear1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="self_attention.core_attention.attention_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- ])
+ policy[GLMBlock] = ModulePolicyDescription(
+ attribute_replacement={
+ "self_attention.num_attention_heads_per_partition":
+ self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attention.projection_size":
+ (self.model.config.kv_channels * self.model.config.num_attention_heads) //
+ self.shard_config.tensor_parallel_size,
+ "self_attention.qkv_hidden_size":
+ (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
+ self.shard_config.tensor_parallel_size,
+ "self_attention.core_attention.num_attention_heads_per_partition":
+ self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attention.core_attention.hidden_size_per_partition":
+ self.model.config.kv_channels * self.model.config.num_attention_heads //
+ self.shard_config.tensor_parallel_size,
+ },
+ param_replacement=[],
+ sub_module_replacement=[
+ SubModuleReplacementDescription(suffix="self_attention.query_key_value",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'seq_parallel_dim': 0,
+ 'overlap': overlap
+ }),
+ SubModuleReplacementDescription(suffix="self_attention.dense",
+ target_module=col_nn.Linear1D_Row,
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'seq_parallel_dim': 0
+ }),
+ SubModuleReplacementDescription(
+ suffix="self_attention.core_attention.attention_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ ])
+
# optimization configuration
if self.shard_config.enable_fused_normalization:
if not self.model.config.rmsnorm:
@@ -124,16 +141,27 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
# use flash attention
if self.shard_config.enable_flash_attention:
- policy[CoreAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_flash_core_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=CoreAttention)
+
+ # use sequence parallel
+ if use_sequence_parallel:
+ self.append_or_create_method_replacement(
+ description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
+ policy=policy,
+ target_key=ChatGLMModel)
# use jit fused operator
if self.shard_config.enable_jit_fused:
- policy[GLMBlock] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_glm_block_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
+ },
+ policy=policy,
+ target_key=GLMBlock)
return policy
@@ -178,7 +206,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
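
ChatGLM2 is the only policy here passing `seq_parallel_dim=0`: its activations are laid out as `(seq_len, batch, hidden)`, so the sequence dimension is 0 rather than the batch-first dim 1 used by BERT and GPT-2. A quick shape check with toy tensors, not ColossalAI code:

```python
import torch

tp_size = 2
chatglm_hidden = torch.randn(32, 4, 64)    # (seq_len, batch, hidden)
gpt2_hidden = torch.randn(4, 32, 64)       # (batch, seq_len, hidden)

chatglm_shard = torch.chunk(chatglm_hidden, tp_size, dim=0)[0]   # seq_parallel_dim=0
gpt2_shard = torch.chunk(gpt2_hidden, tp_size, dim=1)[0]         # default sequence dim
assert chatglm_shard.shape == (16, 4, 64)
assert gpt2_shard.shape == (4, 16, 64)
```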
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 20e5fa372c8f..5093fd469af8 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -6,7 +6,7 @@
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
-from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -37,7 +37,8 @@ def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
-
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
@@ -49,47 +50,55 @@ def module_policy(self):
target_module=col_nn.DropoutForParallelInput,
),
])
- policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
- "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- },
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="attn.c_attn",
- target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={
- "n_fused": 3,
- },
- ),
- SubModuleReplacementDescription(
- suffix="attn.c_proj",
- target_module=col_nn.GPT2FusedLinearConv1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="mlp.c_fc",
- target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={
- "n_fused": 1,
- },
- ),
- SubModuleReplacementDescription(
- suffix="mlp.c_proj",
- target_module=col_nn.GPT2FusedLinearConv1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="attn.attn_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="attn.resid_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="mlp.dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- ])
+
+ policy[GPT2Block] = ModulePolicyDescription(
+ attribute_replacement={
+ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn.c_attn",
+ target_module=col_nn.GPT2FusedLinearConv1D_Col,
+ kwargs={
+ "n_fused": 3,
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
+ ),
+ SubModuleReplacementDescription(suffix="attn.c_proj",
+ target_module=col_nn.GPT2FusedLinearConv1D_Row,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ }),
+ SubModuleReplacementDescription(
+ suffix="mlp.c_fc",
+ target_module=col_nn.GPT2FusedLinearConv1D_Col,
+ kwargs={
+ "n_fused": 1,
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
+ ),
+ SubModuleReplacementDescription(suffix="mlp.c_proj",
+ target_module=col_nn.GPT2FusedLinearConv1D_Row,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ }),
+ SubModuleReplacementDescription(
+ suffix="attn.attn_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.resid_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ ])
# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -117,9 +126,15 @@ def module_policy(self):
target_key=GPT2Block)
if self.shard_config.enable_flash_attention:
- policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_gpt2_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=GPT2Attention)
+
+ if self.shard_config.enable_sequence_parallelism:
+ policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
+
return policy
def postprocess(self):
@@ -160,7 +175,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
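
The `partial(new_forward, stage_manager=..., stage_index=..., shard_config=...)` pattern repeated in every `set_pipeline_forward` pre-binds the stage bookkeeping (and, new in this diff, the shard config) so the replaced `forward` keeps the model's original call signature. A minimal sketch with dict stand-ins for the real manager and config objects:

```python
from functools import partial

def staged_forward(x, *, stage_manager, stage_index, shard_config):
    # the pipeline forward sees its stage metadata as pre-bound keyword args
    return f"stage {stage_index} of {stage_manager['num_stages']} (sp={shard_config['seq_parallel']}) got {x}"

stage_manager = {'num_stages': 4}
shard_config = {'seq_parallel': True}
forward = partial(staged_forward, stage_manager=stage_manager, stage_index=1, shard_config=shard_config)
print(forward("hidden_states"))   # callers still pass only the original positional args
```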
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5ee95f3be8fa..cc131e8168fc 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -35,14 +36,22 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
+
if self.shard_config.enable_tensor_parallelism:
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
+ self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+
policy[LlamaDecoderLayer] = ModulePolicyDescription(
- attribute_replacement={
- "self_attn.hidden_size":
- self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "self_attn.num_heads":
- self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- },
+ attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
@@ -105,9 +114,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
target_key=LlamaModel)
if self.shard_config.enable_flash_attention:
- policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_llama_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=LlamaAttention)
return policy
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index ba6036bd0658..abe491bfaace 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List
@@ -39,6 +40,9 @@ def module_policy(self):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
@@ -100,16 +104,20 @@ def module_policy(self):
# use flash attention
if self.shard_config.enable_flash_attention:
- policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_opt_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=OPTAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
- policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_opt_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
+ },
+ policy=policy,
+ target_key=OPTDecoderLayer)
return policy
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index b1eba0432b49..9753d5a737b9 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -199,12 +199,16 @@ def module_policy(self):
# use flash attention
if self.shard_config.enable_flash_attention:
- policy[SamAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_sam_flash_attention_forward(),
- })
- policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
+ },
+ policy=policy,
+ target_key=SamAttention)
+ self.append_or_create_method_replacement(description={
'forward': get_sam_vision_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=SamVisionAttention)
return policy
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 2ef52c214c6b..92cbd3f72b83 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -1,6 +1,8 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
+import numpy as np
from torch import Tensor, nn
from colossalai.shardformer.layer import (
@@ -58,6 +60,10 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
if self.shard_config.enable_tensor_parallelism:
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
@@ -178,24 +184,33 @@ def module_policy(self):
# use flash attention
if self.shard_config.enable_flash_attention:
- policy[T5Attention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_t5_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=T5Attention)
# use jit operator
if self.shard_config.enable_jit_fused:
- policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_T5_layer_ff_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
- policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
+ },
+ policy=policy,
+ target_key=T5LayerFF)
+ self.append_or_create_method_replacement(description={
'forward': get_T5_layer_self_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
- policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
+ },
+ policy=policy,
+ target_key=T5LayerSelfAttention)
+ self.append_or_create_method_replacement(description={
'forward': get_T5_layer_cross_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
+ },
+ policy=policy,
+ target_key=T5LayerCrossAttention)
+
return policy
def postprocess(self):
@@ -228,13 +243,7 @@ def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
def objective(num_encoder_stages):
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
- num_encoder_stages = 0
- optimal_diff = 2**31 - 1
- for i in range(1, num_stages):
- attempt = objective(i)
- if attempt < optimal_diff:
- num_encoder_stages = i
- optimal_diff = attempt
+ num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
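
The loop-to-`np.argmin` refactor above is behavior-preserving; a quick equivalence check (the `+ 1` restores the 1-based stage count, since the list comprehension starts at `i = 1`):

```python
import numpy as np

num_encoder_layers, num_decoder_layers, num_stages = 12, 12, 4

def objective(num_encoder_stages):
    return abs(num_encoder_layers / num_encoder_stages
               - num_decoder_layers / (num_stages - num_encoder_stages))

# original explicit search loop
best, best_diff = 0, 2**31 - 1
for i in range(1, num_stages):
    if objective(i) < best_diff:
        best, best_diff = i, objective(i)

# refactored one-liner
vectorized = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
assert best == vectorized == 2
```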
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 617720ee7950..b4fb8692e684 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Callable, Dict, List, Union
import torch.nn as nn
@@ -32,6 +33,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
if self.shard_config.enable_tensor_parallelism:
policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -90,16 +95,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
# use flash attention
if self.shard_config.enable_flash_attention:
- policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_vit_flash_self_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=ViTSelfAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
- policy[ViTOutput] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_vit_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
+ },
+ policy=policy,
+ target_key=ViTOutput)
return policy
def new_model_class(self):
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 2ac7a49fd27b..31ba82166b31 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -1,10 +1,17 @@
+import warnings
+from functools import partial
+from typing import Callable, Dict, List, Tuple
+
+import numpy as np
import torch.nn as nn
+from torch import Tensor
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.whisper import (
+ WhisperPipelineForwards,
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
get_whisper_flash_attention_forward,
@@ -12,7 +19,8 @@
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
- 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
+ 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
+ 'WhisperForAudioClassificationPolicy'
]
@@ -26,7 +34,6 @@ def preprocess(self):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
- # TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
@@ -45,6 +52,16 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn(
+ "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
+ # TODO: using the jit fused add_and_dropout affects the accuracy
+ if self.shard_config.enable_jit_fused:
+ self.shard_config.enable_jit_fused = False
+ warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.")
+
if self.shard_config.enable_tensor_parallelism:
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
"self_attn.embed_dim":
@@ -191,20 +208,26 @@ def module_policy(self):
# enable flash attention
if self.shard_config.enable_flash_attention:
- policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_whisper_flash_attention_forward(),
- })
+ },
+ policy=policy,
+ target_key=WhisperAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
- policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
- 'forward': get_jit_fused_whisper_encoder_layer_forward(),
- 'dropout_add': get_jit_fused_dropout_add_func(),
- })
- policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
+ self.append_or_create_method_replacement(description={
'forward': get_jit_fused_whisper_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
- })
+ },
+ policy=policy,
+ target_key=WhisperDecoderLayer)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_whisper_encoder_layer_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=WhisperEncoderLayer)
return policy
@@ -223,6 +246,146 @@ def add_lm_head_policy(self, base_policy):
def postprocess(self):
return self.model
+ @staticmethod
+ def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
+ num_stages: int) -> Tuple[List[int], int]:
+ """
+ Distribute whisper layers into stages when pipeline parallel is used.
+ Return the layer distribution as a list and the starting stage of the decoder.
+ If the decoder doesn't exist, the returned decoder starting stage is set to num_stages.
+ """
+
+ # number of encoder layers must be a positive integer
+ if num_encoder_layers <= 0:
+ raise ValueError("The number of encoder layers for whisper must be a positive integer.")
+
+ # number of layers should be large enough to fill in every stage
+ if num_encoder_layers + num_decoder_layers < num_stages:
+ raise ValueError("The total number of layers can't be smaller than number of stages.")
+
+ # in the case of an encoder-only Whisper model, set decoder starting stage to num_stages since the decoder doesn't exist
+ if num_decoder_layers == 0:
+ return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+
+ # the number of stages distributed between encoder and decoder is optimized in this way:
+ # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
+ # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
+ def objective(num_encoder_stages):
+ return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
+
+ num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
+ num_decoder_stages = num_stages - num_encoder_stages
+
+ encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
+ decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+ return encoder_distribution + decoder_distribution, num_encoder_stages
+
+ @staticmethod
+ def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
+ decoder_starting_stage: int) -> Tuple[bool, int, int]:
+ """
+ Input the distribution of layers among stages, the current stage and the first stage of decoder.
+ Return the starting/ending idx of layers in encoder/decoder
+ """
+ if stage < decoder_starting_stage:
+ return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+ else:
+ return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+
+ def get_held_layers(self) -> List[nn.Module]:
+
+ assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+ stage_manager = self.pipeline_stage_manager
+
+ if self.model.__class__.__name__ == 'WhisperModel':
+ model = self.model
+ elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+ model = self.model.model
+ else:
+ model = None
+
+ if model:
+ encoder = self.model.get_encoder()
+ decoder = self.model.get_decoder()
+ else:
+ # whisper for audio classification holds encoder only
+ encoder = self.model.encoder
+ decoder = None
+
+ num_encoder_layers = len(encoder.layers)
+ if decoder:
+ num_decoder_layers = len(decoder.layers)
+ else:
+ num_decoder_layers = 0
+
+ held_layers = []
+ layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+ start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
+ decoder_starting_stage)
+
+ if stage_manager.stage < decoder_starting_stage:
+ # current stage is in whisper's encoder
+ if stage_manager.is_first_stage():
+ held_layers.append(encoder.embed_positions)
+ held_layers.append(encoder.conv1)
+ held_layers.append(encoder.conv2)
+ if stage_manager.stage == decoder_starting_stage - 1:
+ held_layers.append(encoder.layer_norm)
+ held_layers.extend(encoder.layers[start_idx:end_idx])
+ else:
+ # current stage is in whisper's decoder
+ # TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
+ # the case where encoder and decoder are put in the same stage should be added in the future.
+ if stage_manager.stage == decoder_starting_stage:
+ held_layers.append(decoder.embed_tokens)
+ held_layers.append(decoder.embed_positions)
+ if stage_manager.is_last_stage():
+ held_layers.append(decoder.layer_norm)
+ held_layers.extend(decoder.layers[start_idx:end_idx])
+ return held_layers
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if not self.pipeline_stage_manager:
+ raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
+ stage_manager = self.pipeline_stage_manager
+
+ if self.model.__class__.__name__ == 'WhisperModel':
+ model = self.model
+ elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+ model = self.model.model
+ else:
+ model = None
+
+ if model:
+ encoder = self.model.get_encoder()
+ decoder = self.model.get_decoder()
+ else:
+ encoder = self.model.encoder
+ decoder = None
+
+ num_encoder_layers = len(encoder.layers)
+ if decoder:
+ num_decoder_layers = len(decoder.layers)
+ else:
+ num_decoder_layers = 0
+
+ layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+ stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
+ decoder_starting_stage)
+
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+ }
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
@@ -230,6 +393,24 @@ class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
+ def module_policy(self):
+ from transformers import WhisperModel
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=WhisperModel,
+ new_forward=WhisperPipelineForwards.whisper_model_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ return super().get_held_layers()
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ "no shared params in whisper model"
+ return []
+
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
@@ -238,20 +419,82 @@ def __init__(self) -> None:
super().__init__()
def module_policy(self):
- module_policy = super().module_policy()
- module_policy = self.add_lm_head_policy(module_policy)
- return module_policy
+ from transformers import WhisperForConditionalGeneration
+ policy = super().module_policy()
+ policy = self.add_lm_head_policy(policy)
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
+ new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
+ policy=policy)
+ return policy
def postprocess(self):
- binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"}
- for k, v in binding_map.items():
- param = getattr_(self.model, k)
- setattr_(self.model, v, param)
return self.model
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.proj_out)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ module = self.model
+ model = module.model
+
+ if model:
+ encoder = self.model.get_encoder()
+ decoder = self.model.get_decoder()
+ else:
+ encoder = self.model.encoder
+ decoder = None
+
+ num_encoder_layers = len(encoder.layers)
+ if decoder:
+ num_decoder_layers = len(decoder.layers)
+ else:
+ num_decoder_layers = 0
+
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager is not None and stage_manager.num_stages > 1:
+ _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
+ stage_manager.num_stages)
+ shared_params = []
+ shared_embedding = {}
+ if id(module.proj_out) == id(model.decoder.embed_tokens):
+ shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens
+ shared_embedding[stage_manager.num_stages - 1] = module.proj_out
+ if len(shared_embedding) > 0:
+ shared_params.append(shared_embedding)
+ return shared_params
+ return []
+
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
+
+ def preprocess(self):
+ return self.model
+
+ def module_policy(self):
+ from transformers import WhisperForAudioClassification
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
+ new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.projector)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ return []
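
A worked example of `distribute_whisper_layers` for a 4-encoder-layer, 12-decoder-layer model on 4 pipeline stages, using an even-split stand-in for `Policy.distribute_layers` (the real helper may place remainders differently):

```python
import numpy as np

def distribute_layers(num_layers, num_stages):
    # even-split stand-in for Policy.distribute_layers
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if s < rem else 0) for s in range(num_stages)]

num_encoder_layers, num_decoder_layers, num_stages = 4, 12, 4

def objective(num_encoder_stages):
    return abs(num_encoder_layers / num_encoder_stages
               - num_decoder_layers / (num_stages - num_encoder_stages))

num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
layers_per_stage = (distribute_layers(num_encoder_layers, num_encoder_stages)
                    + distribute_layers(num_decoder_layers, num_stages - num_encoder_stages))
print(num_encoder_stages, layers_per_stage)   # 1 [4, 4, 4, 4]: encoder gets 1 stage, decoder 3
```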
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 0c28f115d018..0b6e1640952b 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -15,31 +15,39 @@ class ShardConfig:
The config for sharding the huggingface model
Args:
- tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group.
- pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline.
- enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
- enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
- enable_all_optimization (bool): Whether to turn on all optimization, default is False.
+ tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism; it's necessary when using tensor parallelism. Defaults to None, which is the global process group.
+ pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
+ enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True.
+ enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
+ enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
+ enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
+ enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
+ enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
+ inference_only (bool): Whether to run only forward passes (inference). Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
- enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
-
- # pipeline_parallel_size: int
- # data_parallel_size: int
+ enable_sequence_parallelism: bool = False
+ enable_sequence_overlap: bool = False
+ enable_all_optimization: bool = False
+ inference_only: bool = False
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
- # inference_only: bool = True
- # gather_output: bool = True
@property
def tensor_parallel_size(self):
return self._tensor_parallel_size
def __post_init__(self):
+ if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
+ raise ValueError(
+ "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
+ if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
+ raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
@@ -57,3 +65,11 @@ def _turn_on_all_optimization(self):
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
+ self.enable_sequence_parallelism = True
+ self.enable_sequence_overlap = True
+
+ def _infer(self):
+ """
+ Set default params for inference.
+ """
+ assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
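
The new `__post_init__` enforces flag dependencies: sequence parallelism requires tensor parallelism, and sequence overlap requires sequence parallelism. A minimal stand-in dataclass reproducing just that validation, not the real `ShardConfig`:

```python
from dataclasses import dataclass

@dataclass
class MiniShardConfig:
    enable_tensor_parallelism: bool = True
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False

    def __post_init__(self):
        if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
            raise ValueError(
                "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
        if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
            raise ValueError(
                "enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")

MiniShardConfig(enable_tensor_parallelism=True, enable_sequence_parallelism=True, enable_sequence_overlap=True)  # ok
try:
    MiniShardConfig(enable_tensor_parallelism=False, enable_sequence_parallelism=True)
except ValueError as e:
    print(e)   # rejected: sequence parallelism without tensor parallelism
```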
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 0ed745a1fc4a..7592069a2dd9 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -27,7 +27,7 @@ class ModelSharder(object):
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
- self.policy = get_autopolicy(self.model) if policy is None else policy
+ self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
self.shard_config = shard_config
def shard(self) -> List[Dict[int, Tensor]]:
@@ -92,22 +92,21 @@ def _recursive_replace_layer(
param_replacement (List[Callable]): The function list to get parameter shard information in policy
method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
+ include (Set[nn.Module], optional): The set of modules to keep on the current device when pipeline parallelism is enabled. Defaults to None.
"""
- # released layers are not shardable
- can_replace_param_or_layer = include is None or module in include
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
(module.__class__ == origin_cls):
if attr_replacement is not None:
self._replace_attr(module, attr_replacement)
- if param_replacement is not None and can_replace_param_or_layer:
+ if param_replacement is not None and (include is None or module in include):
self._replace_param(module, param_replacement)
if method_replacement is not None:
self._replace_method(module, method_replacement)
- if sub_module_replacement is not None and can_replace_param_or_layer:
- self._replace_sub_module(module, sub_module_replacement)
+ if sub_module_replacement is not None:
+ self._replace_sub_module(module, sub_module_replacement, include)
for name, child in module.named_children():
self._recursive_replace_layer(child,
@@ -154,18 +153,17 @@ def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Calla
bound_method = MethodType(new_method, module)
setattr(module, method_name, bound_method)
- def _replace_sub_module(
- self,
- org_layer: nn.Module,
- sub_module_replacement: List[SubModuleReplacementDescription],
- ) -> None:
+ def _replace_sub_module(self,
+ org_layer: nn.Module,
+ sub_module_replacement: List[SubModuleReplacementDescription],
+ include: Optional[Set[nn.Module]] = None) -> None:
r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
Args:
org_layer (torch.nn.Module): The origin layer object to shard
sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
-
+ include (Set[nn.Module], optional): The set of modules to keep on the current device when pipeline parallelism is enabled. Defaults to None.
"""
for description in sub_module_replacement:
suffix = description.suffix
@@ -174,9 +172,12 @@ def _replace_sub_module(
assert target_module is not None, 'target_module should not be None'
- # TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix, ignore=True)
+ # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
+ if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
+ continue
+
assert not isinstance(native_sub_module, target_module), \
f"The module with suffix {suffix} has been replaced, please check the policy"
diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py
index 79b2e3ef936a..6158d0bfe2ad 100644
--- a/colossalai/tensor/d_tensor/comm_spec.py
+++ b/colossalai/tensor/d_tensor/comm_spec.py
@@ -28,7 +28,7 @@ class CommSpec:
to determine the buffer shape, and logical_process_axis
Argument:
- comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
+ comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index c968050de49d..4740a316b7f5 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -2,7 +2,6 @@
import torch
import torch.distributed as dist
-# from colossalai.nn.layer.utils import divide
from numpy import prod
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index 99d782c3f6e8..b837333a2388 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -339,7 +339,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
RS01 -> RR
'''
valid_spec_dict = {}
- comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
+ comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
@@ -362,7 +362,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
b_target_pair = (b_index, [])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
- comm_spec = CommSpec(comm_pathern,
+ comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 7b2e8480c66c..6f9717d353e6 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,12 +1,14 @@
from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (
+ _cast_float,
clip_grad_norm_fp32,
conditional_context,
copy_tensor_parallel_attributes,
count_zeros_fp32,
disposable,
ensure_path_exists,
+ free_storage,
is_ddp_ignored,
is_dp_rank_0,
is_model_parallel_parameter,
@@ -72,4 +74,6 @@
'disposable',
'colo_set_cpu_memory_capacity',
'colo_get_cpu_memory_capacity',
+ '_cast_float',
+ 'free_storage',
]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 8022e84dc24b..998901708239 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -470,3 +470,22 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
+
+
+def free_storage(data: torch.Tensor) -> None:
+ """Free underlying storage of a Tensor."""
+ if data.storage().size() > 0:
+ # Since we're modifying the Tensor's Storage directly, make sure the Tensor
+ # is the sole occupant of the Storage.
+ assert data.storage_offset() == 0
+ data.storage().resize_(0)
+
+
+def _cast_float(args, dtype: torch.dtype):
+ if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
+ args = args.to(dtype)
+ elif isinstance(args, (list, tuple)):
+ args = type(args)(_cast_float(t, dtype) for t in args)
+ elif isinstance(args, dict):
+ args = {k: _cast_float(v, dtype) for k, v in args.items()}
+ return args
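
A quick usage check for the newly exported `_cast_float`, with the definition reproduced from the diff so the snippet runs standalone: it recurses through containers and casts only floating-point tensors.

```python
import torch

def _cast_float(args, dtype: torch.dtype):
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args

batch = {'input_ids': torch.arange(4), 'features': (torch.randn(2), torch.randn(2))}
out = _cast_float(batch, torch.float16)
assert out['input_ids'].dtype == torch.int64                      # integer tensors untouched
assert all(t.dtype == torch.float16 for t in out['features'])     # floats cast to fp16
```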
diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py
index 2318e07a7f8d..881ddde78648 100644
--- a/colossalai/utils/data_sampler/data_parallel_sampler.py
+++ b/colossalai/utils/data_sampler/data_parallel_sampler.py
@@ -4,20 +4,18 @@
import math
import random
-import numpy as np
-from typing import TypeVar, Iterator
+from typing import Iterator, TypeVar
+import numpy as np
import torch
-from torch.utils.data import Sampler, Dataset, DataLoader
+from torch.utils.data import DataLoader, Dataset, Sampler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import DATA_SAMPLERS
T_co = TypeVar('T_co', covariant=True)
-@DATA_SAMPLERS.register_module
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism.
@@ -30,11 +28,7 @@ class DataParallelSampler(Sampler):
the batch size, then the last batch will be smaller, defaults to False.
"""
- def __init__(self,
- dataset: Dataset,
- shuffle: bool = False,
- seed: int = 0,
- drop_last: bool = False) -> None:
+ def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None:
self.dataset = dataset
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
self.rank = gpc.get_local_rank(ParallelMode.DATA)
@@ -54,8 +48,7 @@ def __init__(self,
self.num_replicas # type: ignore[arg-type]
)
else:
- self.num_samples = math.ceil(
- len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
@@ -72,7 +65,7 @@ def __iter__(self) -> Iterator[T_co]:
# set_epoch manually
self.epoch += 1
else:
- indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
@@ -80,8 +73,7 @@ def __iter__(self) -> Iterator[T_co]:
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
- indices += (indices * math.ceil(padding_size /
- len(indices)))[:padding_size]
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
@@ -109,8 +101,8 @@ def set_epoch(self, epoch: int) -> None:
def get_dataloader(dataset,
shuffle=False,
- seed=1024,
- add_sampler=True,
+ seed=1024,
+ add_sampler=True,
drop_last=False,
pin_memory=False,
num_workers=0,
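
The padding branch reformatted above repeats the index list when the dataset is smaller than the padded total; a worked example of that one-liner:

```python
import math

indices = [0, 1, 2]
total_size = 8
padding_size = total_size - len(indices)   # 5, larger than len(indices)
# repeat the index list enough times, then trim to exactly padding_size entries
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
assert indices == [0, 1, 2, 0, 1, 2, 0, 1]
```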
diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/utils/profiler/profiler.py
index 8f43a0b96de0..3026d723deb0 100644
--- a/colossalai/utils/profiler/profiler.py
+++ b/colossalai/utils/profiler/profiler.py
@@ -1,17 +1,17 @@
-import os
-from typing import List
-from colossalai.engine import Engine
-from torch.profiler import profile as torch_profile
-from torch.profiler.profiler import ProfilerAction
-from typing import Any, Callable, Iterable, Optional
-from torch.autograd import ProfilerActivity
+import gzip
import json
import os
import tempfile
-import gzip
+from typing import Any, Callable, Iterable, List, Optional
+
+from torch.autograd import ProfilerActivity
+from torch.profiler import profile as torch_profile
+from torch.profiler.profiler import ProfilerAction
+
+from colossalai.legacy.engine import Engine
+from colossalai.logging import get_dist_logger
from colossalai.utils.profiler.extention import ProfilerExtension
from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
-from colossalai.logging import get_dist_logger
class profile(torch_profile):
diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
index 127055c8c1ef..412bd7277eee 100644
--- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py
+++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
@@ -1,12 +1,14 @@
import os
import threading
import time
-import torch
from enum import Enum
from typing import List
-from colossalai.gemini.stateful_tensor import StatefulTensor
+
+import torch
+
from colossalai.gemini.ophooks import BaseOpHook
-from colossalai.engine import Engine
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.legacy.engine import Engine
from colossalai.utils.profiler.extention import ProfilerExtension
diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py
index 75f8576ca477..dad852a34a71 100644
--- a/colossalai/zero/gemini/colo_init_context.py
+++ b/colossalai/zero/gemini/colo_init_context.py
@@ -87,7 +87,7 @@ def __init__(self,
self._default_dist_spec = default_dist_spec
def _register_colo_modules(self):
- from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
+ from colossalai.legacy.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
register_colo_module(torch.nn.Linear, ColoLinear())
register_colo_module(torch.nn.Embedding, ColoEmbedding())
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 0cd90459b76a..918b08cd3150 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -10,14 +10,13 @@
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
-from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import get_current_device, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@@ -733,7 +732,7 @@ def state_dict_shard(self,
Yields:
Iterator[OrderedDict]: A generator of state dict shard
"""
- sharder = _StateDictSharder(max_shard_size)
+ sharder = StateDictSharder(max_shard_size)
# get the mapping between copies and fp16 parameters
fp16_to_fp32 = dict()
@@ -755,7 +754,7 @@ def state_dict_shard(self,
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
gathered_param = gathered_param_buffer.pop(fp32_param)
- block, block_size = sharder.append(prefix + name, gathered_param)
+ block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:
yield block, block_size
@@ -766,7 +765,7 @@ def state_dict_shard(self,
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
- block, block_size = sharder.append(prefix + name, buffer)
+ block, block_size = sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# save extra states
@@ -774,32 +773,8 @@ def state_dict_shard(self,
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = self.get_extra_state()
- block, block_size = sharder.append(extra_state_key, extra_state)
+ block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
yield sharder.current_block, sharder.current_block_size
-
-
-class _StateDictSharder:
-
- def __init__(self, max_shard_size: int) -> None:
- self.max_shard_size = max_shard_size
- self.current_block = OrderedDict()
- self.current_block_size = 0
-
- def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
- tensor_size = calculate_tensor_size(tensor)
- ret_block = None
- ret_block_size = 0
-
- # before we return the current block and create a new block,
- # we need to ensure that the current block is not empty
- if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
- ret_block = self.current_block
- ret_block_size = self.current_block_size
- self.current_block = OrderedDict()
- self.current_block_size = 0
- self.current_block[name] = tensor
- self.current_block_size += tensor_size
- return ret_block, ret_block_size
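For reference, a sketch of the consolidated `colossalai.checkpoint_io.utils.StateDictSharder` that replaces the helper removed above, reconstructed from this removed class and the inline sharding loop removed from `gemini_optimizer.py` below; the actual implementation may differ in detail:

```python
# Reconstructed sketch, not the verbatim upstream class. The DTensor check's
# import path is an assumption.
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple

import torch

from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.tensor.d_tensor import is_distributed_tensor


class StateDictSharder:

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        tensor_size = calculate_tensor_size(tensor)
        ret_block, ret_block_size = self._maybe_rotate_block(tensor_size)
        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return ret_block, ret_block_size

    def append_optim_state(self, param_id: int, state: Dict[str, Any]) -> Tuple[Optional[OrderedDict], int]:
        # A state may contain several tensors (e.g. Adam's 'step', 'exp_avg',
        # 'exp_avg_sq'); non-tensor entries are skipped when sizing, and states
        # stored as DTensors are skipped here, as in the removed loop.
        state_size = 0
        for state_tensor in state.values():
            if not isinstance(state_tensor, torch.Tensor):
                continue
            if is_distributed_tensor(state_tensor):
                return None, 0
            state_size += calculate_tensor_size(state_tensor)
        ret_block, ret_block_size = self._maybe_rotate_block(state_size)
        self.current_block[param_id] = state
        self.current_block_size += state_size
        return ret_block, ret_block_size

    def _maybe_rotate_block(self, incoming_size: int) -> Tuple[Optional[OrderedDict], int]:
        # Only hand back the current block if it is non-empty and would overflow.
        ret_block, ret_block_size = None, 0
        if self.current_block_size + incoming_size > self.max_shard_size and self.current_block_size > 0:
            ret_block, ret_block_size = self.current_block, self.current_block_size
            self.current_block = OrderedDict()
            self.current_block_size = 0
        return ret_block, ret_block_size
```

The net effect of the refactor is that parameter sharding and optimizer-state sharding now share one block-rotation policy instead of two divergent copies.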
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 175b97647e16..0c593deff225 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -10,7 +10,7 @@
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
@@ -692,49 +692,17 @@ def state_shard(self,
Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
"""
- current_block = {}
- current_block_size = 0
-
+ sharder = StateDictSharder(max_shard_size)
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
- ret_block = None
- ret_block_size = 0
-
- # A state might contain more than one tensors.
- # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
- state_size = 0
- isDTensor = False
- for state_tensor in state.values():
-
- # When state_tensor is not of Tensor class,
- # e.g., a SGD optimizer with momentum set to 0 can have None as state
- # The calculation of tensor size should be skipped to avoid error.
- if not isinstance(state_tensor, torch.Tensor):
- continue
-
- # If the states are stored as DTensors, mark isDTensor as true.
- if is_distributed_tensor(state_tensor):
- isDTensor = True
- state_size += calculate_tensor_size(state_tensor)
-
- if not isDTensor:
-
- if current_block_size + state_size > max_shard_size and current_block_size > 0:
- ret_block = current_block
- ret_block_size = current_block_size
- current_block = {}
- current_block_size = 0
-
- current_block[param_id] = state
- current_block_size += state_size
-
- if ret_block != None:
- yield ret_block, ret_block_size
+ block, block_size = sharder.append_optim_state(param_id, state)
+ if block is not None:
+ yield block, block_size
- yield current_block, current_block_size
+ yield sharder.current_block, sharder.current_block_size
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')
diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
index 0c9eac8b63e3..e5466965cc48 100644
--- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,7 +1,7 @@
import torch.nn
-from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
index 8f8fec64924e..d68a9dc6458f 100644
--- a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
+++ b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
@@ -1,6 +1,6 @@
import torch
-from colossalai.registry import OPHOOKS
+from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
index a2a62fb9788a..6b76a2116a49 100644
--- a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
+++ b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
@@ -1,6 +1,6 @@
import torch
-from colossalai.registry import OPHOOKS
+from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook
diff --git a/colossalai/zero/legacy/sharded_model/zero_hook.py b/colossalai/zero/legacy/sharded_model/zero_hook.py
index 50f4bdfc775d..1815bee3a9e0 100644
--- a/colossalai/zero/legacy/sharded_model/zero_hook.py
+++ b/colossalai/zero/legacy/sharded_model/zero_hook.py
@@ -3,8 +3,8 @@
import torch
import torch.distributed as dist
+from colossalai.legacy.registry import OPHOOKS
from colossalai.logging import get_dist_logger
-from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 96d5902e893f..0bdd6a3e2370 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -6,6 +6,7 @@
import torch
import torch.distributed as dist
+import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@@ -307,7 +308,7 @@ def _add_to_bucket(self, param, group_id):
# or got a grad of param from another group
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
- group_id != self._bucket_store.current_group_id:
+ group_id != self._bucket_store.current_group_id:
self._run_reduction()
padding_size = self._param_store.get_param_padding_size(param)
@@ -337,6 +338,24 @@ def backward(self, loss, retain_graph=False):
self.zero_grad()
+ def backward_by_grad(self, tensor, grad):
+        assert not (self._partition_grads and not self.require_grad_sync), \
+ "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+
+ if self.mixed_precision_mixin is not None:
+ grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
+ torch.autograd.backward(tensor, grad)
+
+ if not self.require_grad_sync:
+ return
+ self._reduce_grad(self._partition_grads)
+
+ # clear reduced grads
+ if self._overlap_communication:
+ torch.cuda.synchronize()
+
+ self.zero_grad()
+
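A brief usage sketch of the new `backward_by_grad` entry point (illustrative only; `recv_grad` and the surrounding setup are hypothetical): pipeline schedules use it to start backward from an upstream gradient instead of a scalar loss.

```python
# Illustrative only: backward from a received gradient, as a pipeline schedule
# would do on a non-last stage; recv_grad() is a hypothetical placeholder.
output = model(inputs)                 # stage output that requires grad
grad_from_next_stage = recv_grad()     # gradient of the loss w.r.t. `output`
optimizer.backward_by_grad(output, grad_from_next_stage)
```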
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
@@ -362,7 +381,6 @@ def zero_grad(self, set_to_none=True):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
-
if not self.require_grad_sync:
return
@@ -553,11 +571,9 @@ def load_state_dict(self, state_dict: Dict):
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
- device = 'cpu' if self._cpu_offload else 'cuda'
- zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
+ zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
- zero_state_dict = dict()
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
@@ -602,3 +618,19 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
ret_block_size += current_block_size
yield ret_block, ret_block_size
+
+ def update_master_params(self, model: nn.Module) -> None:
+ """Update master params from working params
+
+ Args:
+ model (nn.Module): The model to update master params
+ """
+ for p in model.parameters():
+ p_id = id(p)
+ if p_id in self._param_store.working_to_master_param:
+ master_param = self._param_store.working_to_master_param[p_id]
+ padding_size = self._param_store.get_param_padding_size(p)
+ working_param = p.data.view(-1)
+ if padding_size > 0:
+ working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+ master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
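A usage sketch for `update_master_params` (illustrative; the checkpoint path is hypothetical): after the working fp16/bf16 parameters are overwritten, e.g. by loading a model checkpoint, the flattened fp32 master shards must be resynchronized before the next `step()`.

```python
# Illustrative only: reload working params, then resync the master shards.
model.load_state_dict(torch.load('checkpoint.pt'))  # hypothetical path
optimizer.update_master_params(model)
```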
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 945ca4080413..bb5f49bc546b 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## News
+* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
@@ -49,7 +50,7 @@
-
Parallel Training Demo
- - LLaMA
+ - LLaMA 1/2
- GPT-3
- GPT-2
- BERT
@@ -210,7 +211,16 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
(back to top)
## Parallel Training Demo
-### LLaMA
+### LLaMA2
+
+
+
+
+- 70 billion parameter LLaMA2 model training accelerated by 195%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
+
+### LLaMA1
@@ -443,7 +453,7 @@ Colossal-AI项目受一些相关的项目启发而成立,一些项目是我们
}
```
-Colossal-AI has been accepted as an official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+Colossal-AI has been accepted as an official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
(back to top)
diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md
index 1caf58c8734e..384221596885 100644
--- a/docs/source/en/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/en/advanced_tutorials/add_your_parallel.md
@@ -92,14 +92,14 @@ follow the steps below to create a new distributed initialization.
Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce
strategies may be executed for different kinds of parallelism, users can
-inherit `colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
+inherit `colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data
parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own
gradient handler like below:
```python
-from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.registry import GRADIENT_HANDLER
+from colossalai.legacy.engine import BaseGradientHandler
@GRADIENT_HANDLER.register_module
class YourGradientHandler(BaseGradientHandler):
@@ -121,4 +121,5 @@ gradient_handlers = [
Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline
schedules. If you want to modify how the forward and backward passes are executed, you can
-inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+
diff --git a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
index 281fd47554ca..0a94a7f5d691 100644
--- a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
+++ b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -176,7 +176,7 @@ In our latest example, a Gemini + ZeRO DDP model is also defined to reduce overh
```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
- from colossalai.nn.parallel import GeminiDDP
+ from colossalai.zero import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placement_policy,
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 715c15eb6300..36c94fb492cd 100644
--- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -36,14 +36,14 @@ import torch
import torch.nn as nn
from colossalai import nn as col_nn
from colossalai.amp import AMP_TYPE
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils.timer import MultiTimer
from model_zoo.gpt import GPTLMLoss
from torch.nn import functional as F
@@ -268,3 +268,4 @@ def train():
return_output_label=False,
)
```
+
diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index 6adfe4f113da..6dbe338008fa 100644
--- a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -34,11 +34,11 @@ import colossalai
import colossalai.nn as col_nn
import torch
import torch.nn as nn
-from colossalai.builder import build_pipeline_model
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.builder import build_pipeline_model
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from timm.models import vision_transformer as vit
from torchvision import transforms
@@ -51,17 +51,17 @@ from torchvision.datasets import CIFAR10
Generally, we provide 3 ways to build a pipelined model:
-1. `colossalai.builder.build_pipeline_model_from_cfg`
-2. `colossalai.builder.build_pipeline_model`
+1. `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+2. `colossalai.legacy.builder.build_pipeline_model`
3. Split the model by stages by yourself
When your memory can fit the model, you can use the first two methods to build your model, otherwise you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you can just move the corresponding part of model to GPU.
-`colossalai.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size).
+`colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size).
-If you are familiar with `PyTorch`, you can use `colossalai.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly.
+If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly.
-In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model.
+In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model.
When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of pipeline, the first positional argument of `forward()` is the data tensor loaded from data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`.
@@ -245,3 +245,4 @@ def train():
hooks=hook_list,
display_progress=True)
```
+
diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index a2deaeb88893..0ec9d5c3c5de 100644
--- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -78,8 +78,8 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR
-from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.nn.metric import Accuracy
+from colossalai.legacy.trainer import Trainer, hooks
```
- Other modules
@@ -273,8 +273,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
### Build pipeline model (`/hybrid_parallel/model/vit.py`)
Colossal-AI provides two methods to build a pipeline model from the existing model.
-- `colossalai.builder.build_pipeline_model_from_cfg`
-- `colossalai.builder.build_pipeline_model`
+- `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+- `colossalai.legacy.builder.build_pipeline_model`
Besides, you can also build a pipeline model from scratch with Colossal-AI.
```python
@@ -284,11 +284,11 @@ from typing import Callable
import inspect
import torch
from colossalai import nn as col_nn
-from colossalai.registry import LAYERS, MODELS
+from colossalai.legacy.registry import LAYERS, MODELS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
from torch import dtype, nn
from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
@@ -415,7 +415,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw
#### Import modules
```python
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.utils import MultiTimer
import os
@@ -644,3 +644,4 @@ torchrun --standalone --nproc_per_node train_hybrid.py --config ./co
# If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py
```
+
diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md
index 1e75c343c14f..392251ef06b2 100644
--- a/docs/source/en/basics/booster_api.md
+++ b/docs/source/en/basics/booster_api.md
@@ -1,6 +1,6 @@
# Booster API
-Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https://github.com/CjhHa1)
+Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003)
**Prerequisite:**
@@ -9,32 +9,36 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https:/
**Example Code**
-- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
+- [Train ResNet on CIFAR-10 with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
+- [Train LLaMA-1/2 on RedPajama with Booster](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
## Introduction
-In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more friendly. Also calling `colossalai.booster` is the standard procedure before you run into your training loops. In the sections below, I will cover how `colossalai.booster` works and what we should take note of.
+In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more easily. Also, calling `colossalai.booster` is the standard procedure before you run your training loops. In the sections below, we will cover how `colossalai.booster` works and what we should take note of.
### Plugin
Plugin is an important component that manages parallel configuration (eg: The gemini plugin encapsulates the gemini acceleration solution). Currently supported plugins are as follows:
+**_HybridParallelPlugin:_** This plugin wraps the hybrid parallel training acceleration solution. It provides an interface for any combination of tensor parallel, pipeline parallel and data parallel strategies including DDP and ZeRO.
+
**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.
-**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines.
+**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level and can run across multiple machines.
**_LowLevelZeroPlugin:_** This plugin wraps the 1/2 stage of Zero Redundancy Optimizer. Stage 1 : Shards optimizer states across data parallel workers/GPUs. Stage 2 : Shards optimizer states + gradients across data parallel workers/GPUs.
-
**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp.
+More details about the usage of each plugin can be found in the chapter [Booster Plugins](./booster_plugins.md).
+
### API of booster
{{ autodoc:colossalai.booster.Booster }}
## Usage
-In a typical workflow, you should launch distributed environment at the beginning of training script and create objects needed (such as models, optimizers, loss function, data loaders etc.) firstly, then call `colossalai.booster` to inject features into these objects, After that, you can use our booster APIs and these returned objects to continue the rest of your training processes.
+In a typical workflow, you should launch the distributed environment at the beginning of the training script and first create the objects needed for training (such as models, optimizers, loss functions, data loaders, etc.), then call `booster.boost` to inject features into these objects. After that, you can use our booster APIs and the returned objects to continue the rest of your training process.
A pseudo-code example is like below:
@@ -48,15 +52,21 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
def train():
+ # launch colossalai
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+
+ # create plugin and objects for training
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
criterion = lambda x: x.mean()
optimizer = SGD((model.parameters()), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
+
+ # use booster.boost to wrap the training objects
model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
+ # do training as normal, except that the backward should be called by booster
x = torch.randn(4, 3, 224, 224)
x = x.to('cuda')
output = model(x)
@@ -65,14 +75,16 @@ def train():
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
scheduler.step()
+ optimizer.zero_grad()
+ # checkpointing using booster api
save_path = "./model"
- booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
+ booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)
new_model = resnet18()
booster.load_model(new_model, save_path)
```
-[more design details](https://github.com/hpcaitech/ColossalAI/discussions/3046)
+For more design details, please see [this page](https://github.com/hpcaitech/ColossalAI/discussions/3046).
diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md
index b2840fe87441..ea6c11ae2cdc 100644
--- a/docs/source/en/basics/booster_checkpoint.md
+++ b/docs/source/en/basics/booster_checkpoint.md
@@ -13,12 +13,36 @@ We've introduced the [Booster API](./booster_api.md) in the previous tutorial. I
{{ autodoc:colossalai.booster.Booster.save_model }}
-Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to saved checkpoint. It can be a file, if `shard=False`. Otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way. This is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers).
+Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to the saved checkpoint. It can be a file if `shard=False`; otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way, which is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers), so you can use the huggingface `from_pretrained` method to load a model from our sharded checkpoint.
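A minimal sketch of sharded saving, mirroring the argument names used in the Booster API example (treating the MB unit of `size_per_shard` as an assumption):

```python
# Save a boosted model as a sharded, huggingface-compatible checkpoint
# directory; each shard is capped at roughly 1024 MB here.
booster.save_model(model, "./ckpt_dir", shard=True, size_per_shard=1024)
```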
{{ autodoc:colossalai.booster.Booster.load_model }}
Model must be boosted by `colossalai.booster.Booster` before loading. It will detect the checkpoint format automatically, and load in corresponding way.
+If you want to load a pretrained model from Huggingface but the model is too large to be directly loaded through `from_pretrained` on a single device, a recommended way is to download the pretrained weights to a local directory, and use `booster.load_model` to load from that directory after boosting the model. Also, the model should be initialized under a lazy initialization context to avoid OOM. Here is an example pseudocode:
+```python
+from colossalai.lazy import LazyInitContext
+from colossalai.utils import get_current_device
+from huggingface_hub import snapshot_download
+...
+
+# Initialize model under lazy init context
+init_ctx = LazyInitContext(default_device=get_current_device())
+with init_ctx:
+ model = LlamaForCausalLM(config)
+
+...
+
+# Wrap the model through Booster.boost
+model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+# download huggingface pretrained model to local directory.
+model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp")
+
+# load model using booster.load_model
+booster.load_model(model, model_dir)
+...
+```
+
## Optimizer Checkpoint
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md
index c5c45abce8f7..d7532b0ce39b 100644
--- a/docs/source/en/basics/booster_plugins.md
+++ b/docs/source/en/basics/booster_plugins.md
@@ -1,6 +1,6 @@
# Booster Plugins
-Author: [Hongxin Liu](https://github.com/ver217)
+Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003)
**Prerequisite:**
- [Booster API](./booster_api.md)
@@ -15,6 +15,7 @@ We currently provide the following plugins:
- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management.
- [Torch DDP Plugin](#torch-ddp-plugin): It is a wrapper of `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
- [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp.
+- [Hybrid Parallel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixed precision training, TorchDDP and Zero stage 1/2 features. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) strategies efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below.
More plugins are coming soon.
@@ -43,8 +44,6 @@ We've tested compatibility on some famous models, following models may not be su
Compatibility problems will be fixed in the future.
-> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future.
-
### Gemini Plugin
This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](../features/zero_with_chunk.md).
@@ -69,4 +68,24 @@ More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.h
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
+
+### Hybrid Parallel Plugin
+
+This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts:
+
+1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel settings. Shardformer also overloads the model's forward/backward logic to ensure that tensor parallelism and pipeline parallelism work smoothly. In addition, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallelism are injected into the overloaded forward/backward methods by Shardformer. More details can be found in the chapter [Shardformer Doc](../features/shardformer.md).
+
+2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its argument configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md).
+
+3. Torch DDP: This plugin will automatically adopt Pytorch DDP as the data parallel strategy when pipeline parallelism and Zero are not used. More details about its argument configuration can be found in [Pytorch DDP Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
+
+4. Zero: This plugin can adopt Zero 1/2 as the data parallel strategy by setting the `zero_stage` argument to 1 or 2 when initializing the plugin. Zero 1 is compatible with the pipeline parallel strategy, while Zero 2 is not. More details about its argument configuration can be found in [Low Level Zero Plugin](#low-level-zero-plugin).
+
+> ⚠ When using this plugin, only the subset of Huggingface transformers supported by Shardformer is compatible with tensor parallelism, pipeline parallelism and the optimization tools. Mainstream transformers such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 are all supported by Shardformer.
+
+> ⚠ This plugin only supports sharded checkpointing methods for the model/optimizer at present. Unsharded checkpointing methods will be supported in a future release.
+
+{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
+
+
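A hedged configuration sketch for this plugin (the arguments shown follow the plugin's documented knobs; check exact names and defaults against the autodoc above):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative combination: 2-way tensor parallel, 2-way pipeline parallel with
# 4 microbatches, Zero-1 data parallel on the remaining ranks, fp16 precision.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=2,
                              num_microbatches=4,
                              zero_stage=1,
                              precision='fp16')
booster = Booster(plugin=plugin)
```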
diff --git a/docs/source/en/basics/engine_trainer.md b/docs/source/en/basics/engine_trainer.md
index d2f99563f042..e17c37e24a55 100644
--- a/docs/source/en/basics/engine_trainer.md
+++ b/docs/source/en/basics/engine_trainer.md
@@ -64,7 +64,7 @@ Trainer is a more high-level wrapper for the user to execute training with fewer
```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
# build components and initialize with colossalai.initialize
...
@@ -107,7 +107,7 @@ If you want to customize your own hook class, you can inherit `hooks.BaseHook` a
```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import hooks
+from colossalai.legacy.trainer import hooks
class LogMessageHook(hooks.BaseHook):
@@ -344,8 +344,8 @@ for epoch in range(gpc.config.NUM_EPOCHS):
If you wish to train with a trainer object, you can follow the code snippet below:
```python
-from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.nn.metric import Accuracy
+from colossalai.legacy.trainer import Trainer, hooks
# create a trainer object
@@ -387,3 +387,4 @@ python -m torch.distributed.launch --nproc_per_node --master_addr loc
# with trainer
python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
```
+
diff --git a/docs/source/en/basics/model_checkpoint.md b/docs/source/en/basics/model_checkpoint.md
index 70334f1c41e7..c3ba5b04bca2 100644
--- a/docs/source/en/basics/model_checkpoint.md
+++ b/docs/source/en/basics/model_checkpoint.md
@@ -41,7 +41,7 @@ for epoch in range(num_epochs):
#### Save when using trainer
```python
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
@@ -61,3 +61,4 @@ model = ...
load_checkpoint('xxx.pt', model)
... # train or test
```
+
diff --git a/docs/source/en/features/1D_tensor_parallel.md b/docs/source/en/features/1D_tensor_parallel.md
index 7157af210bc5..0f01cfd325e5 100644
--- a/docs/source/en/features/1D_tensor_parallel.md
+++ b/docs/source/en/features/1D_tensor_parallel.md
@@ -7,7 +7,7 @@ Author: Zhengda Bian, Yongbin Li
- [Configure Parallelization](../basics/configure_parallelization.md)
**Example Code**
-- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)
+- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples)
**Related Paper**
- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf)
@@ -42,77 +42,7 @@ Given $P$ processors, we present the theoretical computation and memory cost, as
## Usage
-To enable 1D tensor parallelism for our model, e.g. on 2 GPUs, we need to configure the parallelism setting as below.
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=2, mode='1d'),
-))
-```
-Then Colossal-AI will automatically apply 1D parallelism to all the layers from `colossalai.nn`.
+1D tensor parallelism is implemented by the `Shardformer` feature in the newest version of ColossalAI; a minimal sketch of the new workflow is shown below.
+For more details about the ideas and usage of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
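A minimal sketch of the Shardformer-based workflow (API names follow the shardformer examples of this release and should be treated as assumptions):

```python
import torch.distributed as dist

from colossalai.shardformer import ShardConfig, ShardFormer

# Assume colossalai.launch has already set up a 2-GPU default process group,
# and `model` is an ordinary (e.g. Huggingface) module supported by Shardformer.
shard_config = ShardConfig(tensor_parallel_process_group=dist.group.WORLD,
                           enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)
```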
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-
-Launch Colossal-AI on 2 GPUs and build the model.
-
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-
-m = MLP()
-```
-We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([256, 512])
-Weight of the second linear layer: torch.Size([512, 256])
-```
-The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the column-parallel partitioning, it becomes `[256, 512]`.
-Similarly, the second row-parallel layer partitions the weight `[1024, 256]` into `[512, 256]`.
-
-We can run the model with some random inputs.
-```python
-from colossalai.utils import get_current_device
-
-x = torch.randn((16, 256), device=get_current_device())
-torch.distributed.broadcast(x, src=0) # synchronize input
-
-x = m(x)
-```
-Then we can see the shapes of activation results.
-```shell
-Output of the first linear layer: torch.Size([16, 512])
-Output of the second linear layer: torch.Size([16, 256])
-```
-The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs.
+
diff --git a/docs/source/en/features/2D_tensor_parallel.md b/docs/source/en/features/2D_tensor_parallel.md
index aae8cc9eef97..c79e7d196f8b 100644
--- a/docs/source/en/features/2D_tensor_parallel.md
+++ b/docs/source/en/features/2D_tensor_parallel.md
@@ -60,83 +60,9 @@ Given $P=q\times q$ processors, we present the theoretical computation and memor
## Usage
-To enable 2D tensor parallelism for our model, e.g. on 4 GPUs, we need to configure the parallelism setting as below.
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=4, mode='2d'),
-))
-```
-Then Colossal-AI will automatically apply 2D parallelism to all the layers from `colossalai.nn`.
-
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-Launch Colossal-AI on 4 GPUs and build the model
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-
-m = MLP()
-```
-We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([128, 512])
-Weight of the second linear layer: torch.Size([512, 128])
-```
-The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2D parallelism, it becomes `[128, 512]` on each GPU.
-Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
-
-We can run the model with some random inputs.
-```python
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-
-x = torch.randn((16, 256), device=get_current_device())
-# partition input
-torch.distributed.broadcast(x, src=0)
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
-x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
-print_rank_0(f'Input: {x.shape}')
-
-x = m(x)
-```
-Then we can see the shapes of activation results.
-```shell
-Input: torch.Size([8, 128])
-Output of the first linear layer: torch.Size([8, 512])
-Output of the second linear layer: torch.Size([8, 128])
-```
-The activation tensors in 2D parallelism are all split in both row and column.
-E.g. the output of the first linear layer has the shape `[8, 512]`, while the second layer has the output of `[8, 128]`.
+Currently, the newest version of ColossalAI doesn't support 2D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases.
+For more details about the ideas and usage of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
+
+For users of older versions of ColossalAI, please refer to [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md).
+
+
diff --git a/docs/source/en/features/2p5D_tensor_parallel.md b/docs/source/en/features/2p5D_tensor_parallel.md
index a81d14f10627..b3cbd1c7c727 100644
--- a/docs/source/en/features/2p5D_tensor_parallel.md
+++ b/docs/source/en/features/2p5D_tensor_parallel.md
@@ -58,86 +58,9 @@ Given $P=q \times q \times d$ processors, we present the theoretical computation
## Usage
-To enable 2.5D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below.
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=8, mode='2.5d', depth=2),
-))
-
-```
-Then Colossal-AI will automatically apply 2.5D parallelism to all the layers from `colossalai.nn`.
-
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-Launch Colossal-AI on 8 GPUs and build the model
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-
-m = MLP()
-```
-We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([128, 512])
-Weight of the second linear layer: torch.Size([512, 128])
-```
-The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2.5D parallelism, it becomes `[128, 512]` on each GPU.
-Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
-
-We can run the model with some random inputs.
-```python
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-
-x = torch.randn((16, 256), device=get_current_device())
-# partition input
-torch.distributed.broadcast(x, src=0)
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
-x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
-print_rank_0(f'Input: {x.shape}')
-
-x = m(x)
-```
-Then we can see the shapes of activation results.
-```shell
-Input: torch.Size([4, 128])
-Output of the first linear layer: torch.Size([4, 512])
-Output of the second linear layer: torch.Size([4, 128])
-```
-The activation tensors in 2.5D parallelism are all split by $d \times q$ in the row and $q$ in the column.
-E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`.
-Note, 2.5D parallelism use the same partition method as 2D parallelism for weights, where the difference is the partition of input.
+Currently, the newest version of ColossalAI doesn't support 2.5D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases.
+For more details about the ideas and usage of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
+
+For users of older versions of ColossalAI, please refer to [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md).
+
+
diff --git a/docs/source/en/features/3D_tensor_parallel.md b/docs/source/en/features/3D_tensor_parallel.md
index 0e28f08b23c9..00e6c5fca40c 100644
--- a/docs/source/en/features/3D_tensor_parallel.md
+++ b/docs/source/en/features/3D_tensor_parallel.md
@@ -67,85 +67,9 @@ Given $P=q \times q \times q$ processors, we present the theoretical computation
## Usage
-To enable 3D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below.
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=8, mode='3d'),
-))
-```
-Then Colossal-AI will automatically apply 3D parallelism to all the layers from `colossalai.nn`.
-
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-Launch Colossal-AI on 8 GPUs and build the model
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-
-m = MLP()
-```
-We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([128, 256])
-Weight of the second linear layer: torch.Size([512, 64])
-```
-The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 3D parallelism, it becomes `[128, 256]` on each GPU.
-Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 64]`.
-
-We can run the model with some random inputs.
-```python
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-
-x = torch.randn((16, 256), device=get_current_device())
-# partition input
-torch.distributed.broadcast(x, src=0)
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
-x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
-print_rank_0(f'Input: {x.shape}')
-
-x = m(x)
-```
-Then we can see the shapes of activation results.
-```shell
-Input: torch.Size([4, 128])
-Output of the first linear layer: torch.Size([4, 512])
-Output of the second linear layer: torch.Size([4, 128])
-```
-The activation tensors in 3D parallelism are all split by $q^2$ in the row and $q$ in the column.
-E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`.
-Note, although the results of 3D parallelism have the same shape as that of 2.5D parallelism for weights here, the content of each partition is different.
+Currently, the newest version of ColossalAI doesn't support 3D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases.
+For more details about the ideas and usage of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md).
+
+For users of older versions of ColossalAI, please refer to [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md).
+
+
diff --git a/docs/source/en/features/gradient_handler.md b/docs/source/en/features/gradient_handler.md
index 757016fcb53a..66e5e3a9dfbd 100644
--- a/docs/source/en/features/gradient_handler.md
+++ b/docs/source/en/features/gradient_handler.md
@@ -28,8 +28,8 @@ To implement a customized gradient handler, you need to follow these steps.
3. implement `handle_gradient` method.
```python
-from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.registry import GRADIENT_HANDLER
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module
@@ -61,3 +61,4 @@ to demonstrate the use of gradient handler. In this example, we used `DataParall
```shell
python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py
```
+
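For completeness, a sketch of a customized handler under the updated import paths (the all-reduce body and the `_model` attribute follow the built-in data-parallel handler pattern and should be treated as assumptions):

```python
import torch.distributed as dist

from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
from colossalai.legacy.registry import GRADIENT_HANDLER


@GRADIENT_HANDLER.register_module
class AllReduceGradientHandler(BaseGradientHandler):

    def handle_gradient(self):
        # Average every local gradient across all ranks after backward.
        world_size = dist.get_world_size()
        for param in self._model.parameters():
            if param.grad is not None:
                param.grad.div_(world_size)
                dist.all_reduce(param.grad)
```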
diff --git a/docs/source/en/features/mixed_precision_training.md b/docs/source/en/features/mixed_precision_training.md
index 8579d586ed5f..164b2a21598c 100644
--- a/docs/source/en/features/mixed_precision_training.md
+++ b/docs/source/en/features/mixed_precision_training.md
@@ -267,7 +267,7 @@ from pathlib import Path
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_dataloader
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR
from timm.models import vit_base_patch16_224
from torchvision import datasets, transforms
diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md
index 30654b0b0195..cb19f9815bf2 100644
--- a/docs/source/en/features/pipeline_parallel.md
+++ b/docs/source/en/features/pipeline_parallel.md
@@ -1,14 +1,15 @@
# Pipeline Parallel
-Author: Guangyang Lu, Hongxin Liu, Yongbin Li
+Author: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang
**Prerequisite**
-- [Define Your Configuration](../basics/define_your_config.md)
-- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
-- [Configure Parallelization](../basics/configure_parallelization.md)
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Train with Booster](../basics/booster_api.md)
+- [Shardformer](../features/shardformer.md)
+- [Plugin of Booster](../basics/booster_plugins.md)
**Example Code**
-- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel)
+- [Fine-tune Bert with pipeline](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py)
**Related Paper**
- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
@@ -17,7 +18,7 @@ Author: Guangyang Lu, Hongxin Liu, Yongbin Li
## Quick introduction
-In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use ResNet and Cifar as example.
+In this tutorial, you will learn how to use pipeline parallelism. In Colossal-AI, we use the 1F1B pipeline introduced by Nvidia. Since ViT and ImageNet are too large to use in this case, we use the BERT model and the GLUE dataset as an example.
## Table Of Content
@@ -25,7 +26,7 @@ In this tutorial we will cover:
1. Introduction of 1F1B pipeline.
2. Usage of non-interleaved and interleaved schedule.
-3. Training ResNet with pipeline.
+3. Fine-tune Bert with pipeline.
## Introduction of 1F1B pipeline
@@ -60,100 +61,158 @@ In this schedule, each device can perform computation for multiple subsets of la
This mode is both memory-efficient and time-efficient.
-## Usage of non-interleaved and interleaved schedule
+## Colossal-AI's Implementation
-In Colossal-AI, we provided both non-interleaved(as `PipelineSchedule`) and interleaved schedule(as `InterleavedPipelineSchedule`).
+In Colossal-AI, pipeline parallelism relies on the `scheduler` and [`Shardformer`](../features/shardformer.md). We provide both non-interleaved (`OneForwardOneBackwardSchedule`) and interleaved (`InterleavedSchedule`) schedules, while `Shardformer` implements layer splitting for models and replaces the model's `forward` function to make it compatible with the scheduler.
-You just need to set `NUM_MICRO_BATCHES` in config file and set `NUM_CHUNKS` in config file if you want to use Interleaved Pipeline Schedule. If you certainly know the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `TENSOR_SHAPE` in config file to further reduce communication. Otherwise, you can just ignore `tensor_shape`, and the shape will be exchanged over pipeline stages automatically. Then we will generate an appropriate schedule for you.
+In Colossal-AI, the `HybridParallelPlugin` encapsulates the pipeline execution strategy: it manages the pipeline-parallel communication groups and a scheduler. When a model is boosted with this plugin, its layers are split by calling the `shardformer.optimize` function, and `execute_pipeline` is then called to run the model in segments using `OneForwardOneBackwardSchedule`, the default scheduler of `HybridParallelPlugin`; `InterleavedSchedule` will be integrated later.
-## Training ResNet with pipeline
+You can customize your parallel strategy by setting parameters for the `HybridParallelPlugin`.
-Let's build the `ResNet` model first with Colossal PipelinableContext:
+For more usage details, please refer to the [documentation](../basics/booster_plugins.md) for `HybridParallelPlugin`.
+
+## Fine-tune Bert with pipeline
+
+First, we define the necessary training components, including the model, dataloader, optimizer, lr scheduler and criterion:
```python
-import os
-from typing import Callable, List, Optional, Type, Union
+import argparse
+from typing import Callable, List, Union
+
import torch
import torch.nn as nn
+from data import GLUEDataBuilder
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import (
+ AlbertForSequenceClassification,
+ AutoConfig,
+ BertForSequenceClassification,
+ get_linear_schedule_with_warmup,
+)
+
import colossalai
-import colossalai.nn as col_nn
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, get_dataloader
-from colossalai.context import ParallelMode
-from colossalai.pipeline.pipelinable import PipelinableContext
+# Define some config
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+
+coordinator = DistCoordinator()
+
+def move_to_cuda(batch):
+ return {k: v.cuda() for k, v in batch.items()}
+
+
+# Define the 'criterion' function with two inputs (outputs and inputs); it will be passed to 'execute_pipeline'.
+def _criterion(outputs, inputs):
+ return outputs.loss
+
+# Define the Bert model and a GLUE dataloader first, since the optimizer needs the model.
+# `args.task` comes from argparse in the full example; `plugin` is the
+# HybridParallelPlugin created in the next step (create it before the dataloader).
+model_name = "bert-base-uncased"
+data_builder = GLUEDataBuilder(model_name,
+                               plugin,
+                               args.task,
+                               train_batch_size=BATCH_SIZE,
+                               eval_batch_size=BATCH_SIZE)
+train_dataloader = data_builder.train_dataloader()
+
+cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
+model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+
+# Define the optimizer: no weight decay for biases and LayerNorm weights
+lr = LEARNING_RATE
+no_decay = ["bias", "LayerNorm.weight"]
+optimizer_grouped_parameters = [
+    {
+        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+        "weight_decay": WEIGHT_DECAY,
+    },
+    {
+        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+        "weight_decay": 0.0,
+    },
+]
-from titans.dataloader.cifar10 import build_cifar
-from torchvision.models import resnet50
-from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
+optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
-# Define some config
-BATCH_SIZE = 64
-NUM_EPOCHS = 2
-NUM_CHUNKS = 1
-CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
-
-# Train
-disable_existing_loggers()
-parser = colossalai.get_default_parser()
-args = parser.parse_args()
-colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
-logger = get_dist_logger()
-pipelinable = PipelinableContext()
-
-# build model
-with pipelinable:
- model = resnet50()
-```
-Define an execution sequence.
-```python
-exec_seq = [
- 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool',
- (lambda x: torch.flatten(x, 1), "behind"), 'fc'
-]
-pipelinable.to_layer_list(exec_seq)
+# Define lr_scheduler
+total_steps = len(train_dataloader) * NUM_EPOCHS
+num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_steps,
+)
```
-Partition the model into pipeline.
+Define a booster with the `HybridParallelPlugin`.
```python
-model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
+plugin = HybridParallelPlugin(tp_size=1,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ zero_stage=1,
+ precision='fp16',
+ initial_scale=1)
+booster = Booster(plugin=plugin)
```
-In this tutorial, we use `Trainer` to train `ResNet`:
+Boost these training components with the booster you just created.
```python
-# build criterion
-criterion = nn.CrossEntropyLoss()
-
-# optimizer
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
-# build dataloader
-root = os.environ.get('DATA', './data')
-train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32)
-
-lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1)
-engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
- train_dataloader, test_dataloader,
- lr_scheduler)
-timer = MultiTimer()
+model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+ optimizer,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
+```
-trainer = Trainer(engine=engine, timer=timer, logger=logger)
+Finally, train the model.
-hook_list = [
- hooks.LossHook(),
- hooks.AccuracyHook(col_nn.metric.Accuracy()),
- hooks.LogMetricByEpochHook(logger),
- hooks.LRSchedulerHook(lr_scheduler, by_epoch=True)
-]
-
-trainer.fit(train_dataloader=train_dataloader,
- epochs=NUM_EPOCHS,
- test_dataloader=test_dataloader,
- test_interval=1,
- hooks=hook_list,
- display_progress=True)
+```python
+# Define a train function
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+ train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
+ is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
+ total_step = len(train_dataloader)
+
+ model.train()
+ optimizer.zero_grad()
+    # convert train_dataloader to an iterator
+ train_dataloader_iter = iter(train_dataloader)
+ with tqdm(range(total_step),
+ desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+ disable=not (is_pp_last_stage)) as pbar:
+        # Run the forward and backward passes through the pipeline
+ for _ in pbar:
+ outputs = booster.execute_pipeline(train_dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+            # The loss is only returned on the last pipeline stage
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+# Train model
+for epoch in range(NUM_EPOCHS):
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
-We use `2` pipeline stages and the batch will be split into `4` micro batches.
+We use `2` pipeline stages and a micro-batch size of `1` (these parameters can be configured to appropriate values).
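+
+For example, a deeper pipeline could be configured as follows (hypothetical values; any setting that matches your number of GPUs works):
+
+```python
+# 4 pipeline stages; each batch is split into 4 micro-batches
+plugin = HybridParallelPlugin(tp_size=1,
+                              pp_size=4,
+                              num_microbatches=4,
+                              precision='fp16')
+```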
+
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
new file mode 100644
index 000000000000..4abfff8a3cfa
--- /dev/null
+++ b/docs/source/en/features/shardformer.md
@@ -0,0 +1,281 @@
+# Shardformer
+
+Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer)
+
+**Prerequisite**
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Booster API](../basics/booster_api.md)
+- [Booster Plugins](../basics/booster_plugins.md)
+
+**Example Code**
+- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples)
+- [Enabling Shardformer using HybridParallelPlugin](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)
+
+**Related Paper**
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
+- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)
+- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
+- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198)
+
+## Introduction
+
+When training large transformer models such as LLaMA-2 70B or OPT-175B, model parallelism methods that divide a huge model into smaller shards, such as tensor parallelism or pipeline parallelism, are essential to fit within the limits of GPU memory.
+However, manually cutting a model and rewriting its forward/backward logic can be difficult for users who are not familiar with distributed training.
+Meanwhile, the Huggingface transformers library has gradually become users' first choice of model source, and most mainstream large models have been open-sourced in the Huggingface transformers model library.
+
+Out of this motivation, the ColossalAI team developed **Shardformer**, a feature that automatically prepares popular HuggingFace transformer models for model parallelism (tensor parallelism/pipeline parallelism).
+This module aims to make parallelization hassle-free for users who do not come from a systems background.
+Within a few lines of code, users can turn a model into a state ready for distributed training.
+Also, Shardformer contains various optimization tools for acceleration and memory saving during the forward/backward pass.
+
+## Supporting Information
+
+Model/Feature Compatibility Matrix:
+
+
+| Model/Feature | Tensor Parallel | Pipeline Parallel | Lazy Initialization | xFormers | Flash Attention 2 | JIT Fused Operators | Fused LayerNorm | Sequence Parallel | Sequence Overlap |
+| :-----------: | :-------------: | :---------------: | :-----------------: | :------: | :---------------: | :-----------------: | :-------------: | :---------------: | :--------------: |
+| Llama V1/V2   | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| OPT           | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| BLOOM         | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| ChatGLM 2     | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| BERT          | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| GPT 2         | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| T5            | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| ViT           | ✔️ | ✔️ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Whisper       | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ✔️ | ❌ | ❌ |
+| SAM           | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Blip2         | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+
+List of model families we plan to support in the near future:
+- RoBERTa
+- ALBERT
+- ERNIE
+- GPT Neo
+- GPT-J
+- BEiT
+- SwinTransformer V1/V2
+- Qwen
+
+The support matrix will grow as more models and optimization tools emerge. If you have any suggestions about models or optimizations we should support, feel free to mention them in the [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project.
+
+## Usage
+
+### Shardformer Configuration
+
+The configuration of Shardformer is controlled by the `ShardConfig` class:
+
+{{ autodoc:colossalai.shardformer.ShardConfig }}
+
+If you want to enable the Apex fused LayerNorm, please install `apex`.
+If you want to enable flash attention, please install `flash_attn`.
+In addition, xFormers' `cutlass_op` can serve as a fallback for flash attention.
+
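+For example, a `ShardConfig` enabling these optimizations might look like the following (a minimal sketch; the field names follow the autodoc above and may differ across versions):
+
+```python
+from colossalai.shardformer import ShardConfig
+
+shard_config = ShardConfig(
+    enable_fused_normalization=True,   # requires `apex`
+    enable_flash_attention=True,       # requires `flash_attn`; xFormers' cutlass_op can act as a fallback
+    enable_jit_fused=True,             # JIT fused operators
+)
+```
+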
+### Enabling Shardformer
+
+#### 1. Enabling Shardformer Through Booster (Recommended)
+
+Enabling `Shardformer` through a `Booster` initialized with the `HybridParallelPlugin` is the recommended way to unleash the power of Shardformer.
+The main reason is that pipeline parallelism cannot work without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` makes it possible to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.
+
+More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).
+
+[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that the forward/backward pass is invoked differently depending on whether pipeline parallelism is in use.
+
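+A minimal sketch of this route (assuming `model`, `optimizer` and `criterion` are already created and 2 GPUs are used for tensor parallelism; verify the plugin arguments against your version):
+
+```python
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+
+colossalai.launch_from_torch(config={})
+
+plugin = HybridParallelPlugin(tp_size=2, pp_size=1, enable_all_optimization=True)
+booster = Booster(plugin=plugin)
+model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
+```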
+
+#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)
+
+You can also use Shardformer by calling the Shardformer APIs manually. However, this usage is not recommended, since pipeline parallelism can't run without `Booster`.
+
+[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
+is an example of how to trigger `Shardformer` by calling the Shardformer APIs.
+
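+A minimal sketch of the API route (assuming the distributed environment is already launched and the world size equals the tensor parallel size; the exact signatures may differ across versions):
+
+```python
+import torch.distributed as dist
+from transformers import BertForSequenceClassification
+
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
+shard_config = ShardConfig(tensor_parallel_process_group=dist.group.WORLD,
+                           enable_tensor_parallelism=True)
+shard_former = ShardFormer(shard_config=shard_config)
+# optimize() shards the model in place and returns the sharded model plus shared parameters
+sharded_model, shared_params = shard_former.optimize(model)
+```
+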
+### Precautions
+
+1. When enabling pipeline parallelism, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), as it will cause unexpected errors. Instead, run the forward/backward pass by calling the `booster.execute_pipeline` method.
+
+2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` or `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer can't process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config, as sketched below. This bug will be fixed in a future version of Shardformer.
+
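+   A hypothetical illustration of the fix (names and numbers are made up; only the padding logic matters):
+
+   ```python
+   from transformers import AutoConfig
+
+   tp_size = 4
+   config = AutoConfig.from_pretrained("gpt2", num_labels=3)
+   # pad the label count up to the next multiple of the tensor parallel size: 3 -> 4
+   if config.num_labels % tp_size != 0:
+       config.num_labels += tp_size - config.num_labels % tp_size
+   ```
+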
+3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
+ ```python
+ from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+ ```
+ when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
+
+## How Shardformer Works
+
+Generally, Shardformer works through the following four kinds of *replacements*:
+
+1. Replacing the original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
+The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed ones.
+New `forward` methods also replace the original ones so as to execute distributed computation, such as the split/gather operations of linear layers under tensor parallelism.
+Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module (see the sketch at the end of this section).
+
+2. Replacing attributes of the original Huggingface Transformers layers with attributes appropriate for distributed training.
+For example, when training LLaMA-2 with a tensor parallel size of 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`.
+
+3. Replacing the `forward` methods implemented by the original Huggingface
+Transformers library with our customized `forward` methods.
+This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between different pipeline stages.
+Also, optimization methods such as flash attention or sequence parallelism can be injected into the `forward` process through our customized `forward` method.
+
+4. Replacing the complete copy of model parameters and optimizer states with the incomplete shards controlled by the current device (this is why it's called Shardformer).
+By executing the `ModelSharder.shard` method, the current device keeps only the part of the model parameters it is responsible for:
+the assigned parameter shards when using tensor parallelism, the parameters belonging to the current pipeline stage when using pipeline parallelism, or both.
+All other parameters are released so as to free memory.
+As a result, the optimizer only computes states for this part of the parameters, reducing memory usage even further.
+
+All of these replacements are implemented with manually written policies and forward functions.
+If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
+
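+As an illustration of replacement 1, the sketch below converts a plain `nn.Linear` into its column-parallel counterpart via `from_native_module` (a minimal sketch assuming a tensor-parallel process group is already initialized; the exact signature may differ across versions):
+
+```python
+import torch.distributed as dist
+import torch.nn as nn
+
+from colossalai.shardformer.layer import Linear1D_Col
+
+# Shard a 16 -> 64 linear layer along the output dimension across the TP group;
+# gather_output=True gathers the partial outputs so every rank sees the full result.
+linear = nn.Linear(16, 64)
+tp_linear = Linear1D_Col.from_native_module(linear,
+                                            process_group=dist.group.WORLD,
+                                            gather_output=True)
+```
+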
+### Sequence Parallelism
+
+Sequence parallelism is a special optimization supported by `Shardformer`. Sequence parallelism in `Shardformer` differs slightly from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used only together with 1D tensor parallelism, to further reduce the memory occupied by activation tensors during computation.
+
+1. In normal [1D tensor parallelism](https://colossalai.org/docs/features/1D_tensor_parallel), there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, and $\vec{g}$ performs one All-Reduce in the forward pass to gather the complete outputs from all devices.
+
+2. When using sequence parallelism, $g$ instead performs an All-Gather in the forward pass to gather the inputs along the sequence dimension, and a Reduce-Scatter in the backward pass to split the gradient. $\vec{g}$ performs a Reduce-Scatter to split the output of the `Row Linear` layer across devices along the sequence dimension, and an All-Gather in the backward pass to gather the complete gradient.
+
+3. NCCL's implementation of All-Reduce adopts the `Ring All-Reduce` approach, which consists of one Reduce-Scatter operation and one All-Gather operation of equal cost. Therefore, compared with plain tensor parallelism, enabling sequence parallelism does not introduce additional communication overhead (a brief cost accounting follows this list).
+
+4. One important thing to note is that when sequence parallelism is used together with the `Column Linear` module of tensor parallelism, the complete input is needed during the backward computation of the weight gradients. During the forward pass, only the slice of the input split along the sequence dimension is retained, with shape $(batch,\ sequence\_len/k,\ hidden\_states)$. Therefore, an additional All-Gather is required to obtain the complete input for the gradient computation. In our implementation, however, this All-Gather can be overlapped with the gradient computation, so it introduces no additional communication overhead (this corresponds to the `enable_sequence_overlap` parameter in `Shardformer`).
+
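+To make point 3 concrete: for a message of size $M$ on $P$ devices, the standard ring cost model gives
+
+$$
+T_{\text{All-Reduce}} = 2M\frac{P-1}{P} = \underbrace{M\frac{P-1}{P}}_{\text{Reduce-Scatter}} + \underbrace{M\frac{P-1}{P}}_{\text{All-Gather}},
+$$
+
+so replacing each All-Reduce with one Reduce-Scatter plus one All-Gather moves the same total volume over the network.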
+
+
diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
index 059eb014affd..c4b0f6557926 100644
--- a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
@@ -81,14 +81,14 @@ Colossal-AI provides users with a global context so that they can easily manage
## Gradient Handler
A gradient handler is an object that performs all-reduce on parameter gradients. Since different all-reduce strategies may be executed under different kinds of parallelism, users can inherit
-`colossalai.engine.gradient_handler.BaseGradientHandler` to implement their own strategies. Currently, Colossal-AI uses the plain data-parallel gradient handler to all-reduce gradients across data-parallel ranks.
+`colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their own strategies. Currently, Colossal-AI uses the plain data-parallel gradient handler to all-reduce gradients across data-parallel ranks.
If data parallelism is detected, the gradient handler will be added to the engine automatically.
You can add your own gradient handler as follows:
```python
-from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.registry import GRADIENT_HANDLER
+from colossalai.legacy.engine import BaseGradientHandler
@GRADIENT_HANDLER.register_module
class YourGradientHandler(BaseGradientHandler):
@@ -109,4 +109,5 @@ gradient_handlers = [
## Schedule
A schedule describes how to execute the forward and backward passes. Currently, Colossal-AI provides pipeline and non-pipeline schedules.
-If you want to modify how the forward and backward passes are executed, you can inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+If you want to modify how the forward and backward passes are executed, you can inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+
diff --git a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
index 3f85d50454ae..dfd1e2910b4e 100644
--- a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
+++ b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -160,7 +160,7 @@ for mn, module in model.named_modules():
```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
- from colossalai.nn.parallel import GeminiDDP
+ from colossalai.zero import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placement_policy,
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 6c6dcf6e850d..3f57f39f2838 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -36,14 +36,14 @@ import torch
import torch.nn as nn
from colossalai import nn as col_nn
from colossalai.amp import AMP_TYPE
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils.timer import MultiTimer
from model_zoo.gpt import GPTLMLoss
from torch.nn import functional as F
@@ -273,3 +273,4 @@ def train():
return_output_label=False,
)
```
+
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index 495c7fa36cc1..5ef863dcd423 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -32,11 +32,11 @@ import colossalai
import colossalai.nn as col_nn
import torch
import torch.nn as nn
-from colossalai.builder import build_pipeline_model
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.builder import build_pipeline_model
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from timm.models import vision_transformer as vit
from torchvision import transforms
@@ -48,17 +48,17 @@ from torchvision.datasets import CIFAR10
In general, we provide three ways to build a pipelined model:
-1. `colossalai.builder.build_pipeline_model_from_cfg`
-2. `colossalai.builder.build_pipeline_model`
+1. `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+2. `colossalai.legacy.builder.build_pipeline_model`
3. Split the model by stages yourself
When the model fits in memory, you can use the first two methods to build your model; otherwise you must split the model yourself. The first two methods first build the whole model on the CPU, then split it, and finally you can move the corresponding parts of the model directly to the GPU.
-`colossalai.builder.build_pipeline_model_from_cfg()` receives a model configuration file, and it can split the model either uniformly (by layer) or in a balanced way (by parameter size).
+`colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a model configuration file, and it can split the model either uniformly (by layer) or in a balanced way (by parameter size).
-If you are familiar with `PyTorch`, you can use `colossalai.builder.build_pipeline_model()`, which receives a `torch.nn.Sequential` model and splits it uniformly by layer.
+If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()`, which receives a `torch.nn.Sequential` model and splits it uniformly by layer.
-In this tutorial, we will convert [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model.
+In this tutorial, we will convert [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model.
When the data is **one** `Tensor`, you can use the positional argument of your model's `forward()` to get the data tensor. For the first pipeline stage, the first positional argument of `forward()` is the data tensor loaded from the data loader; for the other stages, it is the output tensor of the previous stage. Note that if a stage is not the last stage, the return value of `forward()` must be a `Tensor`.
@@ -244,3 +244,4 @@ def train():
hooks=hook_list,
display_progress=True)
```
+
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index 5ad08392049e..f7dd8d477a66 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -73,8 +73,8 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR
-from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.nn.metric import Accuracy
+from colossalai.legacy.trainer import Trainer, hooks
```
- Other modules
@@ -256,8 +256,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
### Build the pipeline model (`/hybrid_parallel/model/vit.py`)
Colossal-AI provides two methods to build a pipeline model from an existing model.
-- `colossalai.builder.build_pipeline_model_from_cfg`
-- `colossalai.builder.build_pipeline_model`
+- `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+- `colossalai.legacy.builder.build_pipeline_model`
In addition, you can also use Colossal-AI to build a pipeline model from scratch.
```python
@@ -266,11 +266,11 @@ from typing import Callable
import inspect
import torch
from colossalai import nn as col_nn
-from colossalai.registry import LAYERS, MODELS
+from colossalai.legacy.registry import LAYERS, MODELS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
from torch import dtype, nn
from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
@MODELS.register_module
@@ -380,7 +380,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw
#### Import modules
```python
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.utils import MultiTimer
import os
@@ -589,3 +589,4 @@ torchrun --standalone --nproc_per_node train_hybrid.py --config ./co
# If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py
```
+
diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md
index b2235b73bca1..c59d75d321c0 100644
--- a/docs/source/zh-Hans/basics/booster_api.md
+++ b/docs/source/zh-Hans/basics/booster_api.md
@@ -1,6 +1,6 @@
-# Using Booster
+# Booster API
-Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https://github.com/CjhHa1)
+Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003)
**Prerequisite:**
@@ -11,17 +11,20 @@
-- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
+- [Train ResNet on CIFAR-10 with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
+- [Train LLaMA-1/2 on the RedPajama dataset with Booster](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
## Introduction
-In our new design, `colossalai.booster` replaces `colossalai.initialize` to seamlessly inject features (e.g. model, optimizer, dataloader) into your training components. Using the booster API, you can integrate our parallel strategies into the model to be trained in a friendlier way. Calling `colossalai.booster` is the basic operation before you enter the training loop.
+In our new design, `colossalai.booster` replaces `colossalai.initialize` to seamlessly inject features (e.g. model, optimizer, dataloader) into your training components. Using the booster API, you can integrate our parallel strategies into the model to be trained in a friendlier way. Calling `colossalai.booster` is the normal procedure before you enter the training flow.
In the sections below, we will introduce how `colossalai.booster` works and the details to pay attention to when using it.
### Booster Plugins
Booster plugins are important components that manage parallel configurations (e.g. the gemini plugin encapsulates the gemini acceleration scheme). The currently supported plugins are listed below:
+**_HybridParallelPlugin:_** The HybridParallelPlugin encapsulates the hybrid-parallel acceleration solution. It provides an interface for arbitrary combinations of tensor parallelism, pipeline parallelism and the two data-parallel methods (DDP, Zero).
+
**_GeminiPlugin:_** The GeminiPlugin encapsulates the gemini acceleration solution, i.e. a ZeRO optimization scheme based on chunk-based memory management.
**_TorchDDPPlugin:_** The TorchDDPPlugin encapsulates Pytorch's DDP acceleration solution. It implements data parallelism at the model level and can run across multiple machines.
@@ -30,6 +33,7 @@ Booster plugins are important components that manage parallel configurations (e.g. the gemini plugin encapsulates
**_TorchFSDPPlugin:_** The TorchFSDPPlugin encapsulates Pytorch's FSDP acceleration solution and can be used for training with zero-redundancy-optimizer data parallelism (ZeroDP).
+For more details on how to use the plugins, please refer to the [Booster Plugins](./booster_plugins.md) chapter.
### Booster Interface
@@ -39,7 +43,7 @@ Booster plugins are important components that manage parallel configurations (e.g. the gemini plugin encapsulates
## Usage and Examples
-When training with colossalai, you first need to launch the distributed environment at the beginning of the training script and create the objects you need, such as the model, optimizer, loss function and dataloader. After that, call `colossalai.booster` to inject the features into these objects, and you can then use our booster API for the rest of your training flow.
+When training with colossalai, you first need to launch the distributed environment at the beginning of the training script and create the objects you need, such as the model, optimizer, loss function and dataloader. After that, call `booster.boost` to inject the features into these objects, and you can then use our booster API for the rest of your training flow.
Below is a pseudocode example showing how to train a model with our booster API:
@@ -53,15 +57,21 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
def train():
+ # launch colossalai
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+
+ # create plugin and objects for training
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
criterion = lambda x: x.mean()
optimizer = SGD((model.parameters()), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
+
+ # use booster.boost to wrap the training objects
model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
+ # do training as normal, except that the backward should be called by booster
x = torch.randn(4, 3, 224, 224)
x = x.to('cuda')
output = model(x)
@@ -70,14 +80,16 @@ def train():
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
scheduler.step()
+ optimizer.zero_grad()
+ # checkpointing using booster api
save_path = "./model"
- booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
+ booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)
new_model = resnet18()
booster.load_model(new_model, save_path)
```
-[For more design details please refer to this link](https://github.com/hpcaitech/ColossalAI/discussions/3046)
+For more details of the Booster design, please refer to this [page](https://github.com/hpcaitech/ColossalAI/discussions/3046)
diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md
index 4ed049dcf44f..1ff2e330521c 100644
--- a/docs/source/zh-Hans/basics/booster_checkpoint.md
+++ b/docs/source/zh-Hans/basics/booster_checkpoint.md
@@ -13,32 +13,56 @@
{{ autodoc:colossalai.booster.Booster.save_model }}
-The model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to the checkpoint to be saved. It is a file if `shard=False`; otherwise, it is a folder. If `shard=True`, the checkpoint will be saved in a sharded way, which is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers).
+The model must be wrapped by `colossalai.booster.Booster` before saving. `checkpoint` is the path to the checkpoint to be saved. It is a file if `shard=False`; otherwise, it is a folder. If `shard=True`, the checkpoint will be saved in a sharded way, which is practical when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers), so users can load the model from a sharded checkpoint with huggingface's `from_pretrained` method.
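+
+For example, a sharded save might look like the following (a minimal sketch; `model` must already be wrapped by the booster):
+
+```python
+booster.save_model(model, "./bert_ckpt", shard=True, size_per_shard=10, use_safetensors=True)
+```
+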
{{ autodoc:colossalai.booster.Booster.load_model }}
-The model must be boosted by `colossalai.booster.Booster` before loading. It will automatically detect the checkpoint format and load it accordingly.
+The model must be wrapped by `colossalai.booster.Booster` before loading. It will automatically detect the checkpoint format and load it accordingly.
+
+If you want to load a pretrained model from Huggingface but the model is too large to be loaded directly through `from_pretrained` on a single device, a recommended way is to download the pretrained weights to a local directory, and use `booster.load` to load the model directly from that path after the model has been wrapped. To avoid running out of memory, the model should be initialized under the `Lazy Initialization` context. Here is an example pseudocode:
+```python
+from colossalai.lazy import LazyInitContext
+from huggingface_hub import snapshot_download
+...
+
+# Initialize model under lazy init context
+init_ctx = LazyInitContext(default_device=get_current_device)
+with init_ctx:
+ model = LlamaForCausalLM(config)
+
+...
+
+# Wrap the model through Booster.boost
+model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+# download huggingface pretrained model to local directory.
+model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp")
+
+# load model using booster.load
+booster.load(model, model_dir)
+...
+```
## Optimizer Checkpoint
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
-The optimizer must be boosted by `colossalai.booster.Booster` before saving.
+The optimizer must be wrapped by `colossalai.booster.Booster` before saving.
{{ autodoc:colossalai.booster.Booster.load_optimizer }}
-The optimizer must be boosted by `colossalai.booster.Booster` before loading.
+The optimizer must be wrapped by `colossalai.booster.Booster` before loading.
## Learning Rate Scheduler Checkpoint
{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}
-The learning rate scheduler must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the local path to the checkpoint file.
+The learning rate scheduler must be wrapped by `colossalai.booster.Booster` before saving. `checkpoint` is the local path to the checkpoint file.
{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}
-The learning rate scheduler must be boosted by `colossalai.booster.Booster` before loading. `checkpoint` is the local path to the checkpoint file.
+The learning rate scheduler must be wrapped by `colossalai.booster.Booster` before loading. `checkpoint` is the local path to the checkpoint file.
## Checkpoint Design
diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md
index 0f355c43901c..0ad1cacab151 100644
--- a/docs/source/zh-Hans/basics/booster_plugins.md
+++ b/docs/source/zh-Hans/basics/booster_plugins.md
@@ -1,6 +1,6 @@
# Booster Plugins
-Author: [Hongxin Liu](https://github.com/ver217)
+Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003)
**Prerequisite:**
- [Booster API](./booster_api.md)
@@ -11,10 +11,11 @@
We currently provide the following plugins:
-- [Low Level Zero plugin](#low-level-zero-plugin): It wraps `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with Zero-dp. It only supports Zero stages 1 and 2.
-- [Gemini plugin](#gemini-plugin): It wraps [Gemini](../features/zero_with_chunk.md), which implements Zero-3 based on chunk-based memory management and heterogeneous memory management.
-- [Torch DDP plugin](#torch-ddp-plugin): It wraps `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
-- [Torch FSDP plugin](#torch-fsdp-plugin): It wraps `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with Zero-dp.
+- [Low Level Zero plugin](#low-level-zero-插件): It wraps `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with Zero-dp. It only supports Zero stages 1 and 2.
+- [Gemini plugin](#gemini-插件): It wraps [Gemini](../features/zero_with_chunk.md), which implements Zero-3 based on chunk-based memory management and heterogeneous memory management.
+- [Torch DDP plugin](#torch-ddp-插件): It wraps `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
+- [Torch FSDP plugin](#torch-fsdp-插件): It wraps `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with Zero-dp.
+- [Hybrid Parallel plugin](#hybrid-parallel-插件): It provides a unified and concise interface for Shardformer, the pipeline manager, mixed precision training, TorchDDP and Zero-1/Zero-2. With this plugin, transformer models can be trained simply and efficiently under any combination of tensor parallelism, pipeline parallelism and data parallelism (DDP, Zero), together with various tools for optimizing training speed and memory. Details of these strategies and tools are explained in the next chapter.
More plugins are coming soon.
@@ -43,8 +44,6 @@ Zero-2 does not support local gradient accumulation. If you insist on using it, although you can accumulate
Compatibility issues will be fixed in the future.
-> ⚠ This plugin can currently only load optimizer checkpoints that were saved by itself with the same number of processes. This will be resolved in the future.
-
### Gemini Plugin
This plugin implements Zero-3 based on chunk-based memory management and heterogeneous memory management. It can train large models without losing much speed. It also does not support local gradient accumulation. For more details, please refer to the [Gemini documentation](../features/zero_with_chunk.md).
@@ -70,4 +69,23 @@ Zero-2 does not support local gradient accumulation. If you insist on using it, although you can accumulate
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
+
+### Hybrid Parallel Plugin
+
+This plugin implements the combination of various parallel training strategies and optimization tools. The features of the Hybrid Parallel plugin can be roughly divided into the following four parts:
+
+1. Shardformer: this plugin provides an easy-to-use interface for the Shardformer feature. Shardformer is responsible for splitting the model under tensor parallelism and pipeline parallelism and for overriding the forward/backward methods of the model; it also integrates various optimization tools, including fused normalization, flash attention (xformers), JIT and sequence parallelism, into the overridden forward/backward methods. For more information about Shardformer, please refer to the [Shardformer documentation](../features/shardformer.md).
+
+2. Mixed precision training: the plugin supports fp16/bf16 mixed precision training. For more details about configuring mixed precision training, please refer to the [mixed precision training documentation](../features/mixed_precision_training_with_booster.md).
+
+3. Torch DDP: when pipeline parallelism and Zero are not used, the plugin automatically adopts Pytorch DDP as the data parallel strategy. For more details about configuring Torch DDP, please refer to the [Pytorch DDP documentation](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
+
+4. Zero: you can make the plugin adopt Zero 1/2 as the data parallel strategy by setting the `zero_stage` argument to 1 or 2 when initializing the plugin. Zero 1 can be used together with pipeline parallelism, while Zero 2 cannot. For more details about configuring Zero, please refer to the [Low Level Zero plugin](#low-level-zero-插件) section.
+
+> ⚠ When using this plugin, only the subset of Huggingface transformers models that support Shardformer can use tensor parallelism, pipeline parallelism and the optimization tools. Mainstream transformers models such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 are all supported by Shardformer.
+
+> ⚠ This plugin currently only supports sharded checkpointing for models and optimizers. Unsharded checkpointing will be supported in a future version.
+
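+A minimal construction sketch (parameter names follow the autodoc below; verify them against your version):
+
+```python
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+
+# 2-way tensor parallelism x 2-way pipeline parallelism, with Zero-1 data
+# parallelism on the remaining ranks and fp16 mixed precision
+plugin = HybridParallelPlugin(tp_size=2, pp_size=2, zero_stage=1, precision='fp16')
+booster = Booster(plugin=plugin)
+```
+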
+{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
+
diff --git a/docs/source/zh-Hans/basics/engine_trainer.md b/docs/source/zh-Hans/basics/engine_trainer.md
index a35bd87c44e1..ed5100299212 100644
--- a/docs/source/zh-Hans/basics/engine_trainer.md
+++ b/docs/source/zh-Hans/basics/engine_trainer.md
@@ -61,7 +61,7 @@ The default value of the Trainer argument `schedule` is `None`. In most cases, unless
```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
# build components and initialize with colossalai.initialize
...
@@ -104,7 +104,7 @@ trainer.fit(
```python
from colossalai.logging import get_dist_logger
-from colossalai.trainer import hooks
+from colossalai.legacy.trainer import hooks
class LogMessageHook(hooks.BaseHook):
@@ -340,8 +340,8 @@ for epoch in range(gpc.config.NUM_EPOCHS):
```python
-from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.nn.metric import Accuracy
+from colossalai.legacy.trainer import Trainer, hooks
# create a trainer object
@@ -384,3 +384,4 @@ python -m torch.distributed.launch --nproc_per_node --master_addr loc
# with trainer
python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
```
+
diff --git a/docs/source/zh-Hans/basics/model_checkpoint.md b/docs/source/zh-Hans/basics/model_checkpoint.md
index a5374b7509c9..4a49d373a2a4 100644
--- a/docs/source/zh-Hans/basics/model_checkpoint.md
+++ b/docs/source/zh-Hans/basics/model_checkpoint.md
@@ -41,7 +41,7 @@ for epoch in range(num_epochs):
#### Save with trainer
```python
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...)
@@ -61,3 +61,4 @@ model = ...
load_checkpoint('xxx.pt', model)
... # train or test
```
+
diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md
index 4dd45e8783c3..93fe9ea99422 100644
--- a/docs/source/zh-Hans/features/1D_tensor_parallel.md
+++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md
@@ -6,8 +6,8 @@
- [Define Your Configuration](../basics/define_your_config.md)
- [Configure Parallelization](../basics/configure_parallelization.md)
-**示例代码**
-- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)
+**示例代码**xw
+- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples)
**Related Paper**
- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf)
@@ -41,80 +41,10 @@ $$
| :-: | :-: | :-: | :-: | :-: |
| $O(1/P)$ | $O(1/P)$ | $O(1)$ | $O(2(P-1)/P)$ | $O(2(P-1))$ |
-## Usage
-
-To enable 1D tensor parallelism for the model, e.g. on 2 GPUs, we need to configure the parallel settings as follows.
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=2, mode='1d'),
-))
-```
-
-Then Colossal-AI will automatically apply 1D tensor parallelism to all layers from `colossalai.nn`.
-
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP), as shown below.
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-
-Launch Colossal-AI on 2 GPUs and build the model.
-
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-m = MLP()
-```
-We will see the shapes of the partitioned parameters (e.g. weights) of the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([256, 512])
-Weight of the second linear layer: torch.Size([512, 256])
-```
-The complete weight of the first linear layer should have shape `[256, 1024]`. After the column-parallel split, it becomes `[256, 512]`.
-Likewise, the second row-parallel layer partitions the weight `[1024, 256]` into `[512, 256]`.
-
-We can run the model with some random inputs.
-```python
-from colossalai.utils import get_current_device
+## Usage
-x = torch.randn((16, 256), device=get_current_device())
-torch.distributed.broadcast(x, src=0) # synchronize input
+In the latest version of ColossalAI, 1D tensor parallelism is implemented by the `Shardformer` feature.
+For details on the principles and usage of `Shardformer`, please refer to the Shardformer documentation in the current directory.
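+
+A minimal sketch of enabling 1D tensor parallelism through the Booster (hypothetical parameter values; `model`, `optimizer` and `criterion` are assumed to already exist):
+
+```python
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+
+# tp_size=2 shards each linear layer across 2 GPUs, exactly as described above
+plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
+booster = Booster(plugin=plugin)
+model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
+```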
-x = m(x)
-```
-Then we can see the shapes of the activation results.
-```shell
-Output of the first linear layer: torch.Size([16, 512])
-Output of the second linear layer: torch.Size([16, 256])
-```
-The output of the first linear layer is split into 2 partitions (each of shape `[16, 512]`), while the output of the second layer is identical across all GPUs.
+
diff --git a/docs/source/zh-Hans/features/2D_tensor_parallel.md b/docs/source/zh-Hans/features/2D_tensor_parallel.md
index f163432ecceb..a8e5cf4bfb47 100644
--- a/docs/source/zh-Hans/features/2D_tensor_parallel.md
+++ b/docs/source/zh-Hans/features/2D_tensor_parallel.md
@@ -60,82 +60,8 @@ $$
## Usage
-To enable 2D tensor parallelism for our model, e.g. on 4 GPUs, we need to configure the parallel settings as follows.
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=4, mode='2d'),
-))
-```
-Then Colossal-AI will automatically apply 2D tensor parallelism to all layers from `colossalai.nn`.
-
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP), as shown below.
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-Launch Colossal-AI on 4 GPUs and build the model.
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-
-m = MLP()
-```
-We will see the shapes of the partitioned parameters (e.g. weights) of the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([128, 512])
-Weight of the second linear layer: torch.Size([512, 128])
-```
-The complete weight of the first linear layer should have shape `[256, 1024]`. After the 2D parallel partitioning, it becomes `[128, 512]` on each GPU.
-Likewise, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
-
-We can run the model with some random inputs.
-```python
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-
-x = torch.randn((16, 256), device=get_current_device())
-# partition input
-torch.distributed.broadcast(x, src=0)
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
-x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
-print_rank_0(f'Input: {x.shape}')
-
-x = m(x)
-```
-Then we can see the shapes of the activation results.
-```shell
-Input: torch.Size([8, 128])
-Output of the first linear layer: torch.Size([8, 512])
-Output of the second linear layer: torch.Size([8, 128])
-```
-The activation tensors in 2D parallelism are all split along both rows and columns. For example, the output of the first linear layer has shape `[8, 512]`, while the output of the second layer has shape `[8, 128]`.
+The latest version of ColossalAI does not support 2D tensor parallelism yet, but it will be integrated into `Shardformer` in a future release. For details on the principles and usage of `Shardformer`, please refer to the Shardformer documentation in the current directory.
+
+For users of older versions of ColossalAI, please refer to [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) for the usage of 2D tensor parallelism.
+
+
diff --git a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md
index 5f15202729a7..6b0f1a301804 100644
--- a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md
+++ b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md
@@ -57,89 +57,8 @@ $$
## Usage
-To enable 2.5D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallel settings as follows.
-
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=8, mode='2.5d', depth=2),
-))
-
-```
-
-Then Colossal-AI will automatically apply 2.5D tensor parallelism to all layers from `colossalai.nn`.
-
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP), as shown below.
-
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-Launch Colossal-AI on 8 GPUs and build the model.
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-
-m = MLP()
-```
-We will see the shapes of the partitioned parameters (e.g. weights) of the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([128, 512])
-Weight of the second linear layer: torch.Size([512, 128])
-```
-
-The complete weight of the first linear layer should have shape `[256, 1024]`. After the 2.5D parallel partitioning, it becomes `[128, 512]` on each GPU.
-Likewise, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
-
-We can run the model with some random inputs.
-```python
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-
-x = torch.randn((16, 256), device=get_current_device())
-# partition input
-torch.distributed.broadcast(x, src=0)
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
-x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
-print_rank_0(f'Input: {x.shape}')
-
-x = m(x)
-```
-Then we can see the shapes of the activation results.
-```shell
-Input: torch.Size([4, 128])
-Output of the first linear layer: torch.Size([4, 512])
-Output of the second linear layer: torch.Size([4, 128])
-```
-The activation tensors in 2.5D parallelism are all split along $d \times q$ rows and $q$ columns simultaneously. For example, the output of the first linear layer has shape `[4, 512]`, while the output of the second layer has shape `[4, 128]`.
-Note that 2.5D parallelism uses the same partitioning method as 2D parallelism for the weights; the difference lies in the partitioning of the input.
+The latest version of ColossalAI does not support 2.5D tensor parallelism yet, but it will be integrated into `Shardformer` in a future release. For details on the principles and usage of `Shardformer`, please refer to the Shardformer documentation in the current directory.
+
+For users of older versions of ColossalAI, please refer to [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) for the usage of 2.5D tensor parallelism.
+
+
diff --git a/docs/source/zh-Hans/features/3D_tensor_parallel.md b/docs/source/zh-Hans/features/3D_tensor_parallel.md
index 5ce0cdf6c068..f6154559ec28 100644
--- a/docs/source/zh-Hans/features/3D_tensor_parallel.md
+++ b/docs/source/zh-Hans/features/3D_tensor_parallel.md
@@ -67,88 +67,8 @@ $$
## Usage
-To enable 3D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallel settings as follows.
-
-```python
-CONFIG = dict(parallel=dict(
- data=1,
- pipeline=1,
- tensor=dict(size=8, mode='3d'),
-))
-```
-Then Colossal-AI will automatically apply 3D tensor parallelism to all layers from `colossalai.nn`.
-
-Let's define a model that consists of a two-layer multi-layer perceptron (MLP), as shown below.
-
-```python
-import colossalai
-import colossalai.nn as col_nn
-import torch
-from colossalai.utils import print_rank_0
-
-class MLP(torch.nn.Module):
- def __init__(self, dim: int = 256):
- super().__init__()
- intermediate_dim = dim * 4
- self.dense_1 = col_nn.Linear(dim, intermediate_dim)
- print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
- self.activation = torch.nn.GELU()
- self.dense_2 = col_nn.Linear(intermediate_dim, dim)
- print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
- self.dropout = col_nn.Dropout(0.1)
-
- def forward(self, x):
- x = self.dense_1(x)
- print_rank_0(f'Output of the first linear layer: {x.shape}')
- x = self.activation(x)
- x = self.dense_2(x)
- print_rank_0(f'Output of the second linear layer: {x.shape}')
- x = self.dropout(x)
- return x
-```
-Launch Colossal-AI on 8 GPUs and build the model.
-```python
-parser = colossalai.get_default_parser()
-colossalai.launch(config=CONFIG,
- rank=args.rank,
- world_size=args.world_size,
- local_rank=args.local_rank,
- host=args.host,
- port=args.port)
-
-m = MLP()
-```
-We will see the shapes of the partitioned parameters (e.g. weights) of the MLP model.
-```shell
-Weight of the first linear layer: torch.Size([128, 256])
-Weight of the second linear layer: torch.Size([512, 64])
-```
-
-The complete weight of the first linear layer should have shape `[256, 1024]`. After the 3D parallel partitioning, it becomes `[128, 256]` on each GPU.
-Likewise, the second layer partitions the weight `[1024, 256]` into `[512, 64]`.
-
-We can run the model with some random inputs.
-
-```python
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-
-x = torch.randn((16, 256), device=get_current_device())
-# partition input
-torch.distributed.broadcast(x, src=0)
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
-x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
-x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
-print_rank_0(f'Input: {x.shape}')
-
-x = m(x)
-```
-Then we can see the shapes of the activation results.
-```shell
-Input: torch.Size([4, 128])
-Output of the first linear layer: torch.Size([4, 512])
-Output of the second linear layer: torch.Size([4, 128])
-```
-The activation tensors in 3D parallelism are all split along $q^2$ rows and $q$ columns simultaneously. For example, the output of the first linear layer has shape `[4, 512]`, while the output of the second layer has shape `[4, 128]`.
-Note that although the 3D parallel results here have the same shapes as those of 2.5D parallelism, the contents of each partition are different.
+The latest version of ColossalAI does not support 3D tensor parallelism yet, but it will be integrated into `Shardformer` in a future release. For details on the principles and usage of `Shardformer`, please refer to the Shardformer documentation in the current directory.
+
+For users of older versions of ColossalAI, please refer to [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) for the usage of 3D tensor parallelism.
+
+
diff --git a/docs/source/zh-Hans/features/gradient_handler.md b/docs/source/zh-Hans/features/gradient_handler.md
index 701c60fed57f..3b1140409ba8 100644
--- a/docs/source/zh-Hans/features/gradient_handler.md
+++ b/docs/source/zh-Hans/features/gradient_handler.md
@@ -25,8 +25,8 @@
3. Implement `handle_gradient`
```python
-from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.registry import GRADIENT_HANDLER
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module
@@ -57,3 +57,4 @@ gradient_handler = [dict(type='MyGradientHandler')]
```shell
python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py
```
+
diff --git a/docs/source/zh-Hans/features/mixed_precision_training.md b/docs/source/zh-Hans/features/mixed_precision_training.md
index a92e7e093015..35a73f1adbcd 100644
--- a/docs/source/zh-Hans/features/mixed_precision_training.md
+++ b/docs/source/zh-Hans/features/mixed_precision_training.md
@@ -245,7 +245,7 @@ from pathlib import Path
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_dataloader
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR
from timm.models import vit_base_patch16_224
from torchvision import datasets, transforms
diff --git a/docs/source/zh-Hans/features/pipeline_parallel.md b/docs/source/zh-Hans/features/pipeline_parallel.md
index 98096b1d7f93..e688020556d8 100644
--- a/docs/source/zh-Hans/features/pipeline_parallel.md
+++ b/docs/source/zh-Hans/features/pipeline_parallel.md
@@ -1,14 +1,15 @@
# Pipeline Parallelism
-Author: Guangyang Lu, Hongxin Liu, Yongbin Li
+Author: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang
**Prerequisite**
-- [Define Your Configuration](../basics/define_your_config.md)
-- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
-- [Configure Parallelization](../basics/configure_parallelization.md)
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Booster API](../basics/booster_api.md)
+- [Shardformer](../features/shardformer.md)
+- [Booster Plugins](../basics/booster_plugins.md)
**Example Code**
-- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel)
+- [Fine-tune Bert with pipeline parallelism](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py)
**Related Paper**
- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
@@ -17,7 +18,7 @@
## Quick Overview
-In this tutorial, you will learn how to use pipeline parallelism. In Colossal-AI, we use the 1F1B pipeline introduced by NVIDIA. Since ViT and ImageNet are too large for this case, we use ResNet and CIFAR as an example.
+In this tutorial, you will learn how to use pipeline parallelism. In Colossal-AI, we use the 1F1B pipeline introduced by NVIDIA. Since ViT and ImageNet are too large for this case, we use the Bert model and the GLUE dataset as an example.
## Table of Contents
@@ -25,7 +26,7 @@
1. Introducing the 1F1B pipeline;
2. Using non-interleaved and interleaved schedules;
-3. Training ResNet with pipeline.
+3. Fine-tuning Bert with pipeline.
## Introducing the 1F1B Pipeline
@@ -59,100 +60,154 @@
This mode saves both memory and time.
-## Using schedules
+## Colossal-AI's Implementation
-In Colossal-AI, we provide both a non-interleaved (`PipelineSchedule`) and an interleaved (`InterleavedPipelineSchedule`) schedule.
+In Colossal-AI, pipeline parallelism relies on the `scheduler` and `Shardformer`. We provide both a non-interleaved (`OneForwardOneBackwardSchedule`) and an interleaved (`InterleavedSchedule`) schedule, while Shardformer splits the model's layers and replaces the model's `forward` function to make the model compatible with the scheduler.
-You just need to set `NUM_MICRO_BATCHES` in the config file, and also set `NUM_CHUNKS` if you want to use the interleaved schedule. If you know for certain the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `tensor_shape` in the config file to further reduce communication. Otherwise, you can simply ignore `tensor_shape`, and the shapes will be exchanged between pipeline stages automatically. We will then generate an appropriate schedule for your pipeline-parallel training based on the config file you provide.
+In Colossal-AI, the `HybridParallelPlugin` encapsulates the pipeline execution strategy: it manages the pipeline-parallel communication groups and a `scheduler`. When a model is boosted with this plugin, its layers are split by calling the `shardformer.optimize` function, and `execute_pipeline` is then called to run the parts of the model with the `scheduler`. For now `HybridParallelPlugin` only supports `OneForwardOneBackwardSchedule`; `InterleavedSchedule` will be supported soon.
-## Training ResNet with pipeline
+You can customize your parallel strategy by setting the parameters of `HybridParallelPlugin`. For more usage details, please refer to the `HybridParallelPlugin` [documentation](../basics/booster_plugins.md).
-We first build the `ResNet` model with the Colossal `PipelinableContext`:
+## Fine-tuning the Bert model with pipeline
+
+First, we define the necessary training components, including the `model`, `dataloader`, `optimizer`, `lr_scheduler` and `criterion`:
```python
-import os
-from typing import Callable, List, Optional, Type, Union
+import argparse
+from typing import Callable, List, Union
+
import torch
import torch.nn as nn
+from data import GLUEDataBuilder
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import (
+ AlbertForSequenceClassification,
+ AutoConfig,
+ BertForSequenceClassification,
+ get_linear_schedule_with_warmup,
+)
+
import colossalai
-import colossalai.nn as col_nn
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, get_dataloader
-from colossalai.context import ParallelMode
-from colossalai.pipeline.pipelinable import PipelinableContext
+# Define some config
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+
+coordinator = DistCoordinator()
+
+def move_to_cuda(batch):
+ return {k: v.cuda() for k, v in batch.items()}
+
+# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'.
+def _criterion(outputs, inputs):
+ return outputs.loss
+
-from titans.dataloader.cifar10 import build_cifar
-from torchvision.models import resnet50
-from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
-# Define some config
-BATCH_SIZE = 64
-NUM_EPOCHS = 2
-NUM_CHUNKS = 1
-CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
-
-# Train
-disable_existing_loggers()
-parser = colossalai.get_default_parser()
-args = parser.parse_args()
-colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
-logger = get_dist_logger()
-pipelinable = PipelinableContext()
-
-# build model
-with pipelinable:
-    model = resnet50()
+# Define a dataloader
+# NOTE: `plugin` is the HybridParallelPlugin created alongside the booster in the
+# next step, and `args.task` comes from the script's argparse setup (e.g. 'mrpc').
+model_name = "bert-base-uncased"
+data_builder = GLUEDataBuilder(model_name,
+                               plugin,
+                               args.task,
+                               train_batch_size=BATCH_SIZE,
+                               eval_batch_size=BATCH_SIZE)
+train_dataloader = data_builder.train_dataloader()
+
+# Define Bert model
+cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
+model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+
+# Define optimizer
+lr = LEARNING_RATE
+no_decay = ["bias", "LayerNorm.weight"]
+optimizer_grouped_parameters = [
+    {
+        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+        "weight_decay": WEIGHT_DECAY,
+    },
+    {
+        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+        "weight_decay": 0.0,
+    },
+]
+optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
+
+# Define lr_scheduler
+total_steps = len(train_dataloader) * NUM_EPOCHS
+num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+lr_scheduler = get_linear_schedule_with_warmup(
+    optimizer,
+    num_warmup_steps=num_warmup_steps,
+    num_training_steps=total_steps,
+)
```
-Given the partition order, modules can be specified directly by name, while some functions need to be added manually.
+Initialize a booster with `HybridParallelPlugin`:
```python
-exec_seq = [
- 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool',
- (lambda x: torch.flatten(x, 1), "behind"), 'fc'
-]
-pipelinable.to_layer_list(exec_seq)
+plugin = HybridParallelPlugin(tp_size=1,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ zero_stage=1,
+ precision='fp16',
+ initial_scale=1)
+booster = Booster(plugin=plugin)
```
-Split the model into pipeline stages.
+Use the `booster` to inject the optimization features into the training components:
```python
-model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
+model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+ optimizer,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
```
-We use the `Trainer` to train `ResNet`:
+Finally, train the model:
```python
-# build criterion
-criterion = nn.CrossEntropyLoss()
-
-# optimizer
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
-# build dataloader
-root = os.environ.get('DATA', './data')
-train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32)
-
-lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1)
-engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
- train_dataloader, test_dataloader,
- lr_scheduler)
-timer = MultiTimer()
-
-trainer = Trainer(engine=engine, timer=timer, logger=logger)
-
-hook_list = [
- hooks.LossHook(),
- hooks.AccuracyHook(col_nn.metric.Accuracy()),
- hooks.LogMetricByEpochHook(logger),
- hooks.LRSchedulerHook(lr_scheduler, by_epoch=True)
-]
-
-trainer.fit(train_dataloader=train_dataloader,
- epochs=NUM_EPOCHS,
- test_dataloader=test_dataloader,
- test_interval=1,
- hooks=hook_list,
- display_progress=True)
+# Define a train function
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+ train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
+ is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
+ total_step = len(train_dataloader)
+
+ model.train()
+ optimizer.zero_grad()
+    # convert train_dataloader to an iterator
+ train_dataloader_iter = iter(train_dataloader)
+ with tqdm(range(total_step),
+ desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+ disable=not (is_pp_last_stage)) as pbar:
+ # Forward pass
+ for _ in pbar:
+ outputs = booster.execute_pipeline(train_dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+# Train model
+for epoch in range(NUM_EPOCHS):
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
-We use `2` pipeline stages, and the batch is split into `4` micro batches.
+We use `2` pipeline stages, and the batch is split into micro batches of size `1`. (All of these parameters can be set to values that fit your actual workload.)
+
diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md
new file mode 100644
index 000000000000..fe0e7a63ba44
--- /dev/null
+++ b/docs/source/zh-Hans/features/shardformer.md
@@ -0,0 +1,264 @@
+# Shardformer
+
+Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer)
+
+**Prerequisites**
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Booster API](../basics/booster_api.md)
+- [Booster Plugins](../basics/booster_plugins.md)
+
+**Example Code**
+- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples)
+- [Enabling Shardformer using HybridParallelPlugin](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)
+
+**Related Papers**
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
+- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)
+- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
+- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198)
+
+
+## Introduction
+
+When training large Transformer models such as LLaMa-2 70B or OPT 175B, model parallelism methods (including tensor parallelism and pipeline parallelism) that shard a large model into smaller pieces are essential to fit within GPU memory limits. However, manually cutting a model and rewriting its forward/backward logic is difficult for users unfamiliar with distributed training. Meanwhile, the Huggingface transformers library is gradually becoming users' first choice as a source of models, and most mainstream large models have been open-sourced in the Huggingface transformers model zoo.
+
+Motivated by this, the ColossalAI team developed **Shardformer**, a feature that automatically wraps mainstream Transformer models from HuggingFace for training with tensor parallelism and pipeline parallelism. With it, even users with little knowledge of distributed systems can easily train transformers models in parallel: with only a few lines of code, a model can be converted into a state ready for parallel training. In addition, Shardformer includes a variety of optimization tools for acceleration and memory saving during the forward/backward pass.
+
+## Supported Models and Features
+
+Model/Feature compatibility matrix:
+
+
+| Model/Feature | Tensor Parallel | Pipeline Parallel | Lazy Initialization | xFormers | Flash Attention 2 | JIT Fused Operators | Fused LayerNorm | Sequence Parallel | Sequence Overlap |
+| :-----------: | :-------------: | :---------------: | :-----------------: | :------: | :---------------: | :-----------------: | :-------------: | :---------------: | :--------------: |
+| Llama V1/V2   | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| OPT           | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| BLOOM         | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| ChatGLM 2     | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| BERT          | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| GPT 2         | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| T5            | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| ViT           | ✔️ | ✔️ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Whisper       | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ✔️ | ❌ | ❌ |
+| SAM           | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+| Blip2         | ✔️ | ❌ | ❌ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
+
+Models we plan to support with Shardformer in the near future:
+- RoBERTa
+- ALBERT
+- ERNIE
+- GPT Neo
+- GPT-J
+- BEiT
+- SwinTransformer V1/V2
+- Qwen
+
+The set of supported models and optimization tools will keep growing as more models and optimization tools emerge. If you have any suggestions about models or optimization tools we should support, feel free to join the discussion in the project's [Issues](https://github.com/hpcaitech/ColossalAI/issues) section.
+
+## Usage
+
+### Shardformer Configuration
+
+The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:
+
+{{ autodoc:colossalai.shardformer.ShardConfig }}
+
+If you want to enable Apex Fused LayerNorm, please install `apex`. If you want to enable Flash Attention, please install `flash_attn`. In addition, xFormers's `cutlass_op` can serve as a complementary optimization to Flash Attention.
+
+### Launching Shardformer
+
+#### 1. Launching Shardformer via Booster (Recommended)
+
+Launching `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the recommended usage. The main reason is that pipeline parallelism cannot work properly without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` makes it possible to combine the features of `Shardformer` with other features, such as mixed-precision training or ZeRO.
+
+More details about this usage can be found in the [Booster API documentation](../basics/booster_api.md) and the [Booster plugins documentation](../basics/booster_plugins.md). [Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of launching `Shardformer` through `HybridParallelPlugin`; a minimal sketch follows below.
+
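+Below is a minimal sketch of this workflow (the model choice and parameter values are illustrative assumptions, not requirements):
+
+```python
+import colossalai
+from transformers import BertForSequenceClassification
+
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+
+colossalai.launch_from_torch(config={})
+
+model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
+
+# tp_size=2 shards each layer across 2 devices; Shardformer is applied under the hood
+plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision='fp16', initial_scale=1)
+booster = Booster(plugin=plugin)
+
+# after boosting, the model's layers have been split and its forward function replaced
+model, *_ = booster.boost(model)
+```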
+
+#### 2. Launching Shardformer via Shardformer APIs (Not Recommended)
+
+You can also launch Shardformer by calling the Shardformer APIs manually. However, this usage is not recommended, since pipeline parallelism cannot run correctly without `Booster`.
+
+[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) is an example of launching `Shardformer` through the Shardformer APIs; a rough sketch is shown below.
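+
+The sketch below assumes a tensor-parallel setup that has already been launched with `colossalai.launch`; exact signatures may vary across versions:
+
+```python
+from transformers import BertForMaskedLM
+
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+
+# enable tensor parallelism over the default process group
+shard_config = ShardConfig(enable_tensor_parallelism=True)
+shard_former = ShardFormer(shard_config=shard_config)
+sharded_model, shared_params = shard_former.optimize(model)
+```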
+
+
+### Precautions
+
+1. When pipeline parallelism is enabled, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), as this will lead to unexpected errors. Instead, do the forward/backward pass through the `booster.execute_pipeline` method.
+
+2. When using Shardformer to process classification models such as `GPT2ForSequenceClassification` or `ViTForImageClassification`, please make sure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer cannot process the classifier layer correctly. A simple fix is to add dummy labels in the transformers config (a sketch of this fix follows this list). This bug will be fixed in a future version of Shardformer.
+
+3. Training ChatGLM-2 6B is a little special: since Huggingface Transformers doesn't officially support ChatGLM yet, please import the config/model classes as follows when training ChatGLM-2 with Shardformer:
+   ```python
+   from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+   from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+   ```
+   and initialize the model with these imported classes.
+
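+For note 2 above, a minimal sketch of the dummy-label fix (the model and numbers are illustrative assumptions):
+
+```python
+from transformers import GPT2Config
+
+tp_size = 4
+config = GPT2Config.from_pretrained("gpt2")
+config.num_labels = 3    # e.g. a 3-way classification task
+
+# pad num_labels up to the next multiple of the tensor parallel size (3 -> 4)
+if config.num_labels % tp_size != 0:
+    config.num_labels += tp_size - config.num_labels % tp_size
+```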
+
+## How Shardformer Works
+
+Generally, Shardformer works through the following four kinds of "replacement":
+
+1. Replacing the original PyTorch modules (e.g. `nn.Linear`, `nn.Embedding`) with crafted distributed modules.
+A distributed module keeps the same attributes as the original module, but replaces the original parameters with distributed ones. A new forward function replaces the original one to perform distributed computation, such as the split/gather operations of a linear layer under tensor parallelism. Each distributed module implements a `from_native_module` static method that converts a PyTorch module into its corresponding distributed module (a hypothetical sketch of such a module is given at the end of this section).
+
+2. Replacing the attributes of the original Huggingface Transformers intermediate layers with attributes suitable for parallel training. For example, when training LLaMA-2 with a tensor parallel size of 2, the attribute `num_heads` (the number of attention heads in each layer) of `LlamaDecoderLayer` should be replaced with `model.config.num_attention_heads // 2`.
+
+3. Replacing the forward functions implemented by the original Huggingface transformers library with customized forward functions. This replacement is essential for pipeline parallelism, which requires special forward functions to pass intermediate hidden states between pipeline stages. In addition, optimization methods such as flash attention or sequence parallelism can be injected into the forward pass through these customized forward functions.
+
+4. Replacing the complete copy of model parameters and optimizer states with the part controlled only by the current device. By executing the `ModelSharder.shard` method, the current device keeps only the part of the model parameters it is supposed to handle. Concretely, that part can be the parameter shard assigned to the current machine under tensor parallelism, or the model parameters of the current pipeline stage under pipeline parallelism, or both. All other parameters are released to save memory.
+As a result, the optimizer only computes states for the retained parameters, which further reduces memory usage.
+
+All these replacements are implemented through manually written policies and forward functions. If you want to dive deeper into the design of Shardformer, or customize your own Shardformer policies, please refer to the [Shardformer development documentation](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and the [pipeline parallelism design discussion](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details.
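+
+To illustrate replacement 1, here is a hypothetical, heavily simplified column-parallel linear module (this is not the actual library code; names and details are made up for illustration):
+
+```python
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+
+class ToyColumnParallelLinear(nn.Module):
+    """Hypothetical sketch: shards an nn.Linear's output dimension across ranks."""
+
+    def __init__(self, weight: nn.Parameter, bias: nn.Parameter):
+        super().__init__()
+        self.weight = weight
+        self.bias = bias
+
+    @staticmethod
+    def from_native_module(module: nn.Linear) -> "ToyColumnParallelLinear":
+        rank, world_size = dist.get_rank(), dist.get_world_size()
+        # each rank keeps only its slice of the output dimension
+        w = nn.Parameter(module.weight.chunk(world_size, dim=0)[rank].clone())
+        b = nn.Parameter(module.bias.chunk(world_size, dim=0)[rank].clone())
+        return ToyColumnParallelLinear(w, b)
+
+    def forward(self, x):
+        # each rank computes its output slice; an all-gather along the last
+        # dimension would reassemble the full output when required
+        return F.linear(x, self.weight, self.bias)
+```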
+
+### Sequence Parallelism
+
+Sequence parallelism is a special optimization method supported by `Shardformer`. Sequence parallelism in `Shardformer` differs slightly from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel), which focuses on ring attention. In `Shardformer`, sequence parallelism is used only together with 1D tensor parallelism, to further reduce the memory occupied by activations during computation.
+
+1. In vanilla [1D tensor parallelism](https://colossalai.org/docs/features/1D_tensor_parallel) there are two communication operations, $g$ and $\vec{g}$: $g$ performs one All-Reduce in the backward pass to gather the gradients from all devices, while $\vec{g}$ performs one All-Reduce in the forward pass to gather the outputs from all devices.
+
+2. When sequence parallelism is used, $g$ needs to perform an All-Gather in the forward pass to gather the inputs along the sequence dimension, and a Reduce-Scatter in the backward pass to split the gradients. $\vec{g}$ needs to perform a Reduce-Scatter to split the output of the row-linear layer along the sequence dimension across all devices, and an All-Gather to gather the complete gradients in the backward pass.
+
+3. NCCL's All-Reduce implementation adopts `Ring All-Reduce`, which consists of one Reduce-Scatter and one All-Gather of equal cost. Therefore, combining sequence parallelism with tensor parallelism introduces no extra communication overhead compared to tensor parallelism alone (see the cost sketch after this list).
+
+4. Note that in the `Column Linear` layer of tensor parallelism, the backward computation of gradients requires the complete input, while the forward pass keeps only the part of the input split along the sequence dimension, with a tensor shape like $(batch, sequence\_len/k, hidden\_states)$. An extra All-Gather is therefore needed to obtain the complete input for the gradient computation. In the implementation, however, this gradient computation can be overlapped with the All-Gather communication, which adds no extra communication overhead (this corresponds to the `enable_sequence_overlap` parameter in `Shardformer`).
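+
+To make point 3 concrete, here is the standard ring-collective cost accounting (a textbook cost model assumed here, not derived in this document): for a tensor of $\phi$ elements on $p$ devices, the per-device communication volumes satisfy
+
+$$
+\underbrace{2\,\tfrac{p-1}{p}\,\phi}_{\text{Ring All-Reduce}}
+= \underbrace{\tfrac{p-1}{p}\,\phi}_{\text{Reduce-Scatter}}
++ \underbrace{\tfrac{p-1}{p}\,\phi}_{\text{All-Gather}}
+$$
+
+so replacing the tensor-parallel All-Reduce with the Reduce-Scatter/All-Gather pair used by sequence parallelism moves exactly the same number of bytes.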
+
+
+
diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md
index 7c4147b76457..33c6454ad92c 100644
--- a/examples/images/vit/README.md
+++ b/examples/images/vit/README.md
@@ -3,7 +3,7 @@
Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time.
In our example, we are using pretrained weights of ViT loaded from HuggingFace.
-We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
+We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin (DDP), LowLevelZeroPlugin (Zero1/Zero2), GeminiPlugin (Gemini) and HybridParallelPlugin (any combination of tensor/pipeline/data parallel).
## Run Demo
@@ -25,4 +25,4 @@ You can run benchmark for ViT model by running the following script:
```bash
bash run_benchmark.sh
```
-The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
\ No newline at end of file
+The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py
index e4a873a9eb52..e6c52c4e97fd 100644
--- a/examples/images/vit/args.py
+++ b/examples/images/vit/args.py
@@ -1,124 +1,82 @@
from colossalai import get_default_parser
+
def parse_demo_args():
parser = get_default_parser()
- parser.add_argument(
- "--model_name_or_path",
- type=str,
- default="google/vit-base-patch16-224",
- help="Path to pretrained model or model identifier from huggingface.co/models."
- )
- parser.add_argument(
- "--output_path",
- type=str,
- default="./output_model.bin",
- help="The path of your saved model after finetuning."
- )
+ parser.add_argument("--model_name_or_path",
+ type=str,
+ default="google/vit-base-patch16-224",
+ help="Path to pretrained model or model identifier from huggingface.co/models.")
+ parser.add_argument("--output_path",
+ type=str,
+ default="./output_model",
+ help="The path of your saved model after finetuning.")
parser.add_argument(
"--plugin",
type=str,
default="gemini",
- help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
- )
- parser.add_argument(
- "--num_epoch",
- type=int,
- default=3,
- help="Number of epochs."
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=32,
- help="Batch size (per dp group) for the training dataloader."
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=3e-4,
- help="Initial learning rate (after the potential warmup period) to use."
- )
- parser.add_argument(
- "--warmup_ratio",
- type=float,
- default=0.3,
- help="Ratio of warmup steps against total training steps."
- )
- parser.add_argument(
- "--weight_decay",
- type=float,
- default=0.1,
- help="Weight decay to use."
- )
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="A seed for reproducible training."
- )
+ help=
+ "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
+ )
+ parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.")
+ parser.add_argument("--batch_size",
+ type=int,
+ default=32,
+ help="Batch size (per dp group) for the training dataloader.")
+ parser.add_argument("--tp_size",
+ type=int,
+ default=1,
+ help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.")
+ parser.add_argument("--pp_size",
+ type=int,
+ default=1,
+ help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.")
+ parser.add_argument("--learning_rate",
+ type=float,
+ default=3e-4,
+ help="Initial learning rate (after the potential warmup period) to use.")
+ parser.add_argument("--warmup_ratio",
+ type=float,
+ default=0.3,
+ help="Ratio of warmup steps against total training steps.")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
+ parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
args = parser.parse_args()
return args
+
def parse_benchmark_args():
parser = get_default_parser()
- parser.add_argument(
- "--model_name_or_path",
- type=str,
- default="google/vit-base-patch16-224",
- help="Path to a pretrained model or model identifier from huggingface.co/models."
- )
+ parser.add_argument("--model_name_or_path",
+ type=str,
+ default="google/vit-base-patch16-224",
+ help="Path to a pretrained model or model identifier from huggingface.co/models.")
parser.add_argument(
"--plugin",
type=str,
default="gemini",
- help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=8,
- help="Batch size (per dp group) for the training dataloader."
- )
- parser.add_argument(
- "--num_labels",
- type=int,
- default=10,
- help="Number of labels for classification."
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=5e-5,
- help="Initial learning rate (after the potential warmup period) to use."
- )
- parser.add_argument(
- "--weight_decay",
- type=float,
- default=0.0,
- help="Weight decay to use."
- )
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=20,
- help="Total number of training steps to perform."
- )
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="A seed for reproducible training."
- )
- parser.add_argument(
- "--mem_cap",
- type=int,
- default=0,
- help="Limit on the usage of space for each GPU (in GB)."
- )
+ help=
+ "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
+ )
+ parser.add_argument("--batch_size",
+ type=int,
+ default=8,
+ help="Batch size (per dp group) for the training dataloader.")
+ parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.")
+ parser.add_argument("--learning_rate",
+ type=float,
+ default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use.")
+ parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+ parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
+ parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.")
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).")
args = parser.parse_args()
- return args
\ No newline at end of file
+ return args
diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py
index 00fde707b173..77a8ad525056 100644
--- a/examples/images/vit/data.py
+++ b/examples/images/vit/data.py
@@ -1,32 +1,38 @@
import torch
-from torch.utils.data import Dataset
from datasets import load_dataset
+from torch.utils.data import Dataset
+
class BeansDataset(Dataset):
-
- def __init__(self, image_processor, split='train'):
+
+ def __init__(self, image_processor, tp_size=1, split='train'):
super().__init__()
self.image_processor = image_processor
self.ds = load_dataset('beans')[split]
self.label_names = self.ds.features['labels'].names
+ while len(self.label_names) % tp_size != 0:
+            # ensure that the number of labels is a multiple of tp_size
+ self.label_names.append(f"pad_label_{len(self.label_names)}")
self.num_labels = len(self.label_names)
self.inputs = []
for example in self.ds:
self.inputs.append(self.process_example(example))
-
+
def __len__(self):
return len(self.inputs)
def __getitem__(self, idx):
return self.inputs[idx]
-
+
def process_example(self, example):
input = self.image_processor(example['image'], return_tensors='pt')
input['labels'] = example['labels']
return input
-
+
def beans_collator(batch):
- return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
- 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)}
+ return {
+ 'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
+ 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)
+ }
diff --git a/examples/images/vit/run_benchmark.sh b/examples/images/vit/run_benchmark.sh
index 2487bf81ee2b..41eab9c5a188 100644
--- a/examples/images/vit/run_benchmark.sh
+++ b/examples/images/vit/run_benchmark.sh
@@ -5,23 +5,20 @@ export BS=8
export MEMCAP=0
export GPUNUM=1
-for BS in 8 32 128
+for BS in 8 32
do
-for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
-do
-for GPUNUM in 1 4
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
do
MODEL_PATH="google/vit-base-patch16-224"
torchrun \
--standalone \
- --nproc_per_node ${GPUNUM} \
+ --nproc_per_node 4 \
vit_benchmark.py \
--model_name_or_path ${MODEL_PATH} \
--mem_cap ${MEMCAP} \
--plugin ${PLUGIN} \
--batch_size ${BS}
-
-done
+
done
done
diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh
index 2d140dd6e423..9efe1475956d 100644
--- a/examples/images/vit/run_demo.sh
+++ b/examples/images/vit/run_demo.sh
@@ -5,16 +5,21 @@ pip install -r requirements.txt
MODEL="google/vit-base-patch16-224"
# path for saving model
-OUTPUT_PATH="./output_model.bin"
+OUTPUT_PATH="./output_model"
# plugin(training strategy)
-# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
+# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel"
PLUGIN="gemini"
+#PLUGIN="hybrid_parallel"
+
+# configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel"
+TP_SIZE=2
+PP_SIZE=2
# number of gpus to use
GPUNUM=4
-# batch size per gpu
+# batch size per data parallel group
BS=16
# learning rate
@@ -38,6 +43,8 @@ torchrun \
--output_path ${OUTPUT_PATH} \
--plugin ${PLUGIN} \
--batch_size ${BS} \
+ --tp_size ${TP_SIZE} \
+ --pp_size ${PP_SIZE} \
--num_epoch ${EPOCH} \
--learning_rate ${LR} \
--weight_decay ${WEIGHT_DECAY} \
diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh
index 8606015c0397..570147606636 100644
--- a/examples/images/vit/test_ci.sh
+++ b/examples/images/vit/test_ci.sh
@@ -2,18 +2,15 @@ set -xe
pip install -r requirements.txt
BS=8
-for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
-do
-for GPUNUM in 1 4
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
do
torchrun \
--standalone \
- --nproc_per_node ${GPUNUM} \
+ --nproc_per_node 4 \
vit_benchmark.py \
--model_name_or_path "google/vit-base-patch16-224" \
--plugin ${PLUGIN} \
--batch_size ${BS}
done
-done
diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py
index c2293b96ad73..d822fe23ecf0 100644
--- a/examples/images/vit/vit_benchmark.py
+++ b/examples/images/vit/vit_benchmark.py
@@ -1,14 +1,14 @@
import time
import torch
-import tqdm
import transformers
from args import parse_benchmark_args
+from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -24,7 +24,7 @@ def format_num(num: int, bytes=False):
num /= factor
-def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
+def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
pixel_values = torch.randn(batch_size,
num_channels,
height,
@@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
device=torch.cuda.current_device(),
dtype=torch.float)
labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
- return pixel_values, labels
+ return dict(pixel_values=pixel_values, labels=labels)
def colo_memory_cap(size_in_GB):
@@ -70,7 +70,8 @@ def main():
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
- model.gradient_checkpointing_enable()
+ if args.grad_checkpoint:
+ model.gradient_checkpointing_enable()
# Set plugin
booster_kwargs = {}
@@ -82,34 +83,57 @@ def main():
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ elif args.plugin == 'hybrid_parallel':
+ plugin = HybridParallelPlugin(tp_size=2,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ precision='fp16',
+ initial_scale=1)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
+ # Set criterion (loss function)
+ def criterion(outputs, inputs):
+ return outputs.loss
+
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
- model, optimizer, _, _, _ = booster.boost(model, optimizer)
+ model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
# Start training.
logger.info(f"Start testing", ranks=[0])
- progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
torch.cuda.synchronize()
model.train()
start_time = time.time()
- for _ in range(args.max_train_steps):
-
- pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
- optimizer.zero_grad()
- outputs = model(pixel_values=pixel_values, labels=labels)
- loss = outputs['loss']
- booster.backward(loss, optimizer)
- optimizer.step()
-
- torch.cuda.synchronize()
- progress_bar.update(1)
+ with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
+ for _ in pbar:
+ optimizer.zero_grad()
+ batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)
+
+ if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+ # run pipeline forward backward
+ batch = iter([batch])
+ outputs = booster.execute_pipeline(batch,
+ model,
+ criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ else:
+ outputs = model(**batch)
+ loss = criterion(outputs, None)
+ # Backward
+ booster.backward(loss, optimizer)
+
+ optimizer.step()
+
+ torch.cuda.synchronize()
# Compute Statistics
end_time = time.time()
@@ -124,6 +148,8 @@ def main():
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
+ torch.cuda.empty_cache()
+
if __name__ == "__main__":
main()
diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py
index 4dc0f67f40bf..206d8694b8f5 100644
--- a/examples/images/vit/vit_train_demo.py
+++ b/examples/images/vit/vit_train_demo.py
@@ -1,70 +1,111 @@
+from typing import Any, Callable, Iterator
+
import torch
import torch.distributed as dist
+import torch.nn as nn
import transformers
from args import parse_demo_args
from data import BeansDataset, beans_collator
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
-def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
+def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
+ data_iter: Iterator, booster: Booster):
+ if optimizer is not None:
+ optimizer.zero_grad()
+ if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
+ # run pipeline forward backward when enabling pp in hybrid parallel plugin
+ output_dict = booster.execute_pipeline(data_iter,
+ model,
+ criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ loss, outputs = output_dict['loss'], output_dict['outputs']
+ else:
+ batch = next(data_iter)
+ batch = move_to_cuda(batch, torch.cuda.current_device())
+ outputs = model(**batch)
+ loss = criterion(outputs, None)
+ if optimizer is not None:
+ booster.backward(loss, optimizer)
- torch.cuda.synchronize()
- model.train()
+ return loss, outputs
- with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
- for batch in pbar:
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
+ lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
- # Foward
- optimizer.zero_grad()
- batch = move_to_cuda(batch, torch.cuda.current_device())
- outputs = model(**batch)
- loss = outputs['loss']
+ torch.cuda.synchronize()
- # Backward
- booster.backward(loss, optimizer)
+ num_steps = len(dataloader)
+ data_iter = iter(dataloader)
+ enable_pbar = coordinator.is_master()
+ if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
+ # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar
+ tp_rank = dist.get_rank(booster.plugin.tp_group)
+ dp_rank = dist.get_rank(booster.plugin.dp_group)
+ enable_pbar = tp_rank == 0 and dp_rank == 0 \
+ and booster.plugin.stage_manager.is_last_stage()
+
+ model.train()
+
+ with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar:
+ for _ in pbar:
+ loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster)
optimizer.step()
lr_scheduler.step()
# Print batch loss
- pbar.set_postfix({'loss': loss.item()})
+ if enable_pbar:
+ pbar.set_postfix({'loss': loss.item()})
@torch.no_grad()
-def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
+def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor],
+ eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+ torch.cuda.synchronize()
model.eval()
- accum_loss = torch.zeros(1, device=get_current_device())
- total_num = torch.zeros(1, device=get_current_device())
- accum_correct = torch.zeros(1, device=get_current_device())
+ accum_loss = torch.zeros(1, device=torch.cuda.current_device())
+ total_num = torch.zeros(1, device=torch.cuda.current_device())
+ accum_correct = torch.zeros(1, device=torch.cuda.current_device())
for batch in eval_dataloader:
batch = move_to_cuda(batch, torch.cuda.current_device())
- outputs = model(**batch)
- val_loss, logits = outputs[:2]
- accum_loss += (val_loss / len(eval_dataloader))
- if num_labels > 1:
+ loss, outputs = run_forward_backward(model, None, criterion, iter([batch]), booster)
+
+ to_accum = True
+ if isinstance(booster.plugin, HybridParallelPlugin):
+ # when using hybrid parallel, loss is only collected from last stage of pipeline with tp_rank == 0
+ to_accum = to_accum and (dist.get_rank(booster.plugin.tp_group) == 0)
+ if booster.plugin.pp_size > 1:
+ to_accum = to_accum and booster.plugin.stage_manager.is_last_stage()
+
+ if to_accum:
+ accum_loss += (loss / len(eval_dataloader))
+ logits = outputs["logits"]
preds = torch.argmax(logits, dim=1)
- elif num_labels == 1:
- preds = logits.squeeze()
- labels = batch["labels"]
- total_num += batch["labels"].shape[0]
- accum_correct += (torch.sum(preds == labels))
+ labels = batch["labels"]
+ total_num += batch["labels"].shape[0]
+ accum_correct += (torch.sum(preds == labels))
dist.all_reduce(accum_loss)
dist.all_reduce(total_num)
@@ -94,14 +135,20 @@ def main():
else:
transformers.utils.logging.set_verbosity_error()
+ # Reset tp_size and pp_size to 1 if not using hybrid parallel.
+ if args.plugin != 'hybrid_parallel':
+ args.tp_size = 1
+ args.pp_size = 1
+
# Prepare Dataset
image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
- train_dataset = BeansDataset(image_processor, split='train')
- eval_dataset = BeansDataset(image_processor, split='validation')
+ train_dataset = BeansDataset(image_processor, args.tp_size, split='train')
+ eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation')
+ num_labels = train_dataset.num_labels
# Load pretrained ViT model
config = ViTConfig.from_pretrained(args.model_name_or_path)
- config.num_labels = train_dataset.num_labels
+ config.num_labels = num_labels
config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
@@ -110,7 +157,8 @@ def main():
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
- model.gradient_checkpointing_enable()
+ if args.grad_checkpoint:
+ model.gradient_checkpointing_enable()
# Set plugin
booster_kwargs = {}
@@ -122,6 +170,16 @@ def main():
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ elif args.plugin == 'hybrid_parallel':
+ plugin = HybridParallelPlugin(tp_size=args.tp_size,
+ pp_size=args.pp_size,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ precision='fp16',
+ initial_scale=1)
+ else:
+ raise ValueError(f"Plugin with name {args.plugin} is not supported!")
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Prepare dataloader
@@ -139,6 +197,10 @@ def main():
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
+ # Set criterion (loss function)
+ def criterion(outputs, inputs):
+ return outputs.loss
+
# Set lr scheduler
total_steps = len(train_dataloader) * args.num_epoch
num_warmup_steps = int(args.warmup_ratio * total_steps)
@@ -148,20 +210,21 @@ def main():
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
- model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
- optimizer=optimizer,
- dataloader=train_dataloader,
- lr_scheduler=lr_scheduler)
+ model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ dataloader=train_dataloader,
+ lr_scheduler=lr_scheduler)
# Finetuning
logger.info(f"Start finetuning", ranks=[0])
for epoch in range(args.num_epoch):
- train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
- evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator)
+ train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator)
+ evaluate_model(epoch, model, criterion, eval_dataloader, booster, coordinator)
logger.info(f"Finish finetuning", ranks=[0])
# Save the finetuned model
- booster.save_model(model, args.output_path)
+ booster.save_model(model, args.output_path, shard=True)
logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py
new file mode 100644
index 000000000000..67ff13bb5f5e
--- /dev/null
+++ b/examples/inference/bench_bloom.py
@@ -0,0 +1,100 @@
+import argparse
+import os
+import time
+
+import torch
+from transformers import BloomForCausalLM, BloomTokenizerFast
+
+import colossalai
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+
+
+def print_perf_stats(latency_set, config, bs, warmup=3):
+ # trim warmup queries
+ latency_set = list(latency_set)
+ latency_set = latency_set[warmup:]
+ count = len(latency_set)
+
+ if count > 0:
+ latency_set.sort()
+ avg = sum(latency_set) / count
+ num_layers = getattr(config, "num_layers", config.num_hidden_layers)
+ num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
+ num_bytes = 2 # float16
+
+ print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
+ print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
+ print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
+ print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
+
+
+def bench_bloom(args):
+ model_path = args.path
+ max_batch_size = args.batch_size
+ max_input_len = args.input_len
+ max_output_len = args.output_len
+
+ tokenizer = BloomTokenizerFast.from_pretrained(model_path)
+ tokenizer.pad_token = tokenizer.eos_token
+ model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
+ model = model.half()
+
+ # init TPInferEngine and shard the original model
+ # To benchmark torch original, comment out the line of optimizing model
+ shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+ infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
+
+ # prepare data for generation
+ generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
+ input_tokens = {
+ "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
+ "attention_mask": torch.ones((max_batch_size, max_input_len))
+ }
+ for t in input_tokens:
+ if torch.is_tensor(input_tokens[t]):
+ input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
+ print(f" input_tokens[{t}].shape: {input_tokens[t].shape}")
+
+ iters = 10
+ times = []
+ for i in range(iters):
+ torch.cuda.synchronize()
+ start = time.time()
+ outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+ torch.cuda.synchronize()
+ end = time.time()
+ out_len = outputs.shape[1]
+ print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
+ times.append((end - start) / (out_len - max_input_len))
+
+ print_perf_stats(times, model.config, max_batch_size)
+
+
+def check_bloom(rank, world_size, port, args):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ bench_bloom(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bloom(args):
+ spawn(check_bloom, args.tp_size, args=args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
+ parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
+ parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
+ parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
+ parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
+
+ args = parser.parse_args()
+
+ test_bloom(args)
diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py
new file mode 100644
index 000000000000..d2016a4587e6
--- /dev/null
+++ b/examples/inference/bench_llama.py
@@ -0,0 +1,128 @@
+import argparse
+import os
+import time
+
+import torch
+from torch.profiler import ProfilerActivity, profile, record_function
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+import colossalai
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+
+
+def init_to_get_rotary(self, base=10000):
+ self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+ if not hasattr(self.config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+ if hasattr(self.config, "max_sequence_length"):
+ max_seq_len = self.config.max_sequence_length
+ elif hasattr(self.config, "max_position_embeddings"):
+ max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+ inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
+ self.config.head_dim_))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+ self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ return
+
+
+def print_perf_stats(latency_set, config, bs, warmup=3):
+ # trim warmup queries
+ latency_set = list(latency_set)
+ latency_set = latency_set[warmup:]
+ count = len(latency_set)
+
+ if count > 0:
+ latency_set.sort()
+ avg = sum(latency_set) / count
+ num_layers = getattr(config, "num_layers", config.num_hidden_layers)
+ num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
+ num_bytes = 2
+
+ print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
+ print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
+ print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
+
+
+def run_llama_test(args):
+ llama_model_path = args.path
+ max_batch_size = args.batch_size
+ max_input_len = args.input_len
+ max_output_len = args.output_len
+
+ tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
+ tokenizer.pad_token_id = tokenizer.unk_token_id
+ model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
+ init_to_get_rotary(model.model, base=10000)
+ model = model.half()
+
+ model_config = model.config
+
+ shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+ infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
+
+ generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
+ input_tokens = {
+ "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
+ "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
+ }
+
+ iters = 10
+ times = []
+
+ for i in range(iters):
+ torch.cuda.synchronize()
+ start = time.time()
+ outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+ torch.cuda.synchronize()
+ end = time.time()
+ out_len = outputs.shape[1]
+ print("generation time {} s".format(str(end - start)))
+ times.append((end - start) / (out_len - max_input_len))
+
+ print("outputs, ", len(outputs))
+ print_perf_stats(times, model_config, max_batch_size)
+
+ with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
+ with record_function("model_inference"):
+ torch.cuda.synchronize()
+ outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+ torch.cuda.synchronize()
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
+def check_llama(rank, world_size, port, args):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_llama_test(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama(args):
+ spawn(check_llama, args.tp_size, args=args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
+ parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
+ parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
+ parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
+ parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
+
+ args = parser.parse_args()
+
+ test_llama(args)
diff --git a/examples/language/bert/README.md b/examples/language/bert/README.md
index da38e8375bf0..6601edb7960e 100644
--- a/examples/language/bert/README.md
+++ b/examples/language/bert/README.md
@@ -7,13 +7,15 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh
```
-### Results on 2-GPU
+### Bert Fine-tuning Results
+
+| Plugin | Accuracy | F1-score | GPU number |
+| -------------- | -------- | -------- | -------- |
+| torch_ddp | 84.4% | 88.6% | 2 |
+| torch_ddp_fp16 | 84.7% | 88.8% | 2 |
+| gemini | 84.0% | 88.4% | 2 |
+| hybrid_parallel | 84.5% | 88.6% | 4 |
-| Plugin | Accuracy | F1-score |
-| -------------- | -------- | -------- |
-| torch_ddp | 84.4% | 88.6% |
-| torch_ddp_fp16 | 84.7% | 88.8% |
-| gemini | 84.0% | 88.4% |
## Benchmark
```
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index 59f10a77c22d..fb6e4332c2f9 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -1,12 +1,14 @@
import argparse
-from typing import List, Union
+from contextlib import nullcontext
+from typing import Callable, List, Union
import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
-from torch.optim import Optimizer
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
@@ -18,8 +20,9 @@
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@@ -32,38 +35,82 @@
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1
+output_transform_fn = lambda x: x
+criterion = lambda x: x.loss
+
def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}
@torch.no_grad()
-def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
- task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
+def evaluate_model(
+ model: nn.Module,
+ criterion,
+ test_dataloader: Union[DataLoader, List[DataLoader]],
+ num_labels: int,
+ task_name: str,
+ eval_splits: List[str],
+ booster: Booster,
+ coordinator: DistCoordinator,
+):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()
def evaluate_subset(dataloader: DataLoader):
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
- outputs = model(**batch)
- val_loss, logits = outputs[:2]
- accum_loss.add_(val_loss)
-
- if num_labels > 1:
- preds = torch.argmax(logits, axis=1)
- elif num_labels == 1:
- preds = logits.squeeze()
-
labels = batch["labels"]
-
- metric.add_batch(predictions=preds, references=labels)
+ if use_pipeline:
+ pg_mesh = booster.plugin.pg_mesh
+ pp_group = booster.plugin.pp_group
+ current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
+ current_rank = dist.get_rank()
+ batch = iter([batch])
+ outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
+
+ if is_pp_last_stage:
+ logits = outputs["outputs"]["logits"]
+ val_loss = outputs["loss"]
+ accum_loss.add_(val_loss)
+
+ if num_labels > 1:
+ preds = torch.argmax(logits, axis=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
+ dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)
+
+ metric.add_batch(predictions=preds, references=labels)
+ elif current_rank in current_pp_group_ranks:
+ object_list = [None, None]
+ dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)
+
+ metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
+ accum_loss.add_(object_list[1].to(get_current_device()))
+
+ else:
+ batch = move_to_cuda(batch)
+ outputs = model(**batch)
+ val_loss, logits = outputs[:2]
+ accum_loss.add_(val_loss)
+
+ if num_labels > 1:
+ preds = torch.argmax(logits, axis=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
+ metric.add_batch(predictions=preds, references=labels)
results = metric.compute()
dist.all_reduce(accum_loss.div_(len(dataloader)))
- if coordinator.is_master():
+ if coordinator.is_master() and results is not None:
results['loss'] = accum_loss.item() / coordinator.world_size
+
return results
if isinstance(test_dataloader, DataLoader):
@@ -77,25 +124,44 @@ def evaluate_subset(dataloader: DataLoader):
return final_results
-def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
- booster: Booster, coordinator: DistCoordinator):
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+ train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+ total_step = len(train_dataloader)
+
model.train()
- with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
- for batch in pbar:
- # Forward pass
- batch = move_to_cuda(batch)
- outputs = model(**batch)
- loss = outputs[0]
+ optimizer.zero_grad()
+ train_dataloader_iter = iter(train_dataloader)
+ with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not print_flag) as pbar:
+ # Forward pass
+ for _ in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(train_dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+ else:
+ data = next(train_dataloader_iter)
+ data = move_to_cuda(data)
+ outputs = model(**data)
+ loss = _criterion(outputs, None)
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({'loss': loss.item()})
- # Backward and optimize
- booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
- # Print log info
- pbar.set_postfix({'loss': loss.item()})
-
def main():
# ==============================
@@ -107,7 +173,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
- choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
help="plugin to use")
parser.add_argument(
"--model_type",
@@ -116,6 +182,7 @@ def main():
help="bert or albert",
)
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
+ parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()
if args.model_type == 'bert':
@@ -124,13 +191,13 @@ def main():
model_name = "albert-xxlarge-v2"
else:
raise RuntimeError
+
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={}, seed=42)
coordinator = DistCoordinator()
- # local_batch_size = BATCH_SIZE // coordinator.world_size
lr = LEARNING_RATE * coordinator.world_size
# ==============================
@@ -145,6 +212,17 @@ def main():
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ elif args.plugin == 'hybrid_parallel':
+
+ # modify the param accordingly for finetuning test cases
+ plugin = HybridParallelPlugin(tp_size=1,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ zero_stage=1,
+ precision='fp16',
+ initial_scale=1)
booster = Booster(plugin=plugin, **booster_kwargs)
@@ -165,8 +243,9 @@ def main():
# bert pretrained model
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
+
if model_name == "bert-base-uncased":
- model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
+ model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
elif model_name == "albert-xxlarge-v2":
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
else:
@@ -196,19 +275,27 @@ def main():
num_training_steps=total_steps,
)
+ def _criterion(outputs, inputs):
+ outputs = output_transform_fn(outputs)
+ loss = criterion(outputs)
+ return loss
+
# ==============================
# Boost with ColossalAI
# ==============================
- model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
+ model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+ optimizer,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
# ==============================
# Train model
# ==============================
for epoch in range(NUM_EPOCHS):
- train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
- results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
- coordinator)
+ results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task,
+ data_builder.eval_splits, booster, coordinator)
if coordinator.is_master():
print(results)
diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh
index 7fc6daabb2f3..394ff831b855 100755
--- a/examples/language/bert/test_ci.sh
+++ b/examples/language/bert/test_ci.sh
@@ -3,6 +3,6 @@ set -xe
pip install -r requirements.txt
-for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
+for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
done
diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md
index 47d24a4d69cb..03679e66404a 100644
--- a/examples/language/gpt/README.md
+++ b/examples/language/gpt/README.md
@@ -65,6 +65,16 @@ Titans provides a customized GPT model, which uses distributed operators as buil
In [./titans/README.md], we provide a hybrid parallelism of ZeRO, TP and PP.
You can switch parallel strategies using a config file.
+### Hybrid parallelism
+
+The `HybridParallelPlugin` provides a user-friendly way to combine multiple parallelism methods for training and inference. In [./hybridparallelism], we provide an example of fine-tuning GPT-2 with hybrid parallelism.
+
+Quick run
+```bash
+cd ./hybridparallelism
+bash run.sh
+```
+
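+Under the hood, the example builds a `HybridParallelPlugin` and hands it to the `Booster`. A minimal sketch, assuming `model`, `optimizer`, `criterion` and `lr_scheduler` are prepared as in `finetune.py` (the parallel sizes below are illustrative placeholders, not tuned values):
+
+```python
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+
+# tp_size/pp_size control tensor and pipeline parallelism; fp16 with ZeRO
+# stage 1 mirrors the configuration used in finetune.py
+plugin = HybridParallelPlugin(tp_size=1, pp_size=2, microbatch_size=1,
+                              zero_stage=1, precision='fp16', initial_scale=1)
+booster = Booster(plugin=plugin)
+model, optimizer, criterion, _, lr_scheduler = booster.boost(
+    model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler)
+```
+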
## Performance
Testbed: a cluster of 8xA100 (80GB) and 1xAMD EPYC 7543 32-Core Processor (512 GB). GPUs are connected via PCI-e.
diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py
new file mode 100644
index 000000000000..981cedcca8c2
--- /dev/null
+++ b/examples/language/gpt/hybridparallelism/data.py
@@ -0,0 +1,127 @@
+import datasets
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
+
+
+class GLUEDataBuilder:
+
+ task_text_field_map = {
+ "cola": ["sentence"],
+ "sst2": ["sentence"],
+ "mrpc": ["sentence1", "sentence2"],
+ "qqp": ["question1", "question2"],
+ "stsb": ["sentence1", "sentence2"],
+ "mnli": ["premise", "hypothesis"],
+ "qnli": ["question", "sentence"],
+ "rte": ["sentence1", "sentence2"],
+ "wnli": ["sentence1", "sentence2"],
+ "ax": ["premise", "hypothesis"],
+ }
+
+ glue_task_num_labels = {
+ "cola": 2,
+ "sst2": 2,
+ "mrpc": 2,
+ "qqp": 2,
+ "stsb": 1,
+ "mnli": 3,
+ "qnli": 2,
+ "rte": 2,
+ "wnli": 2,
+ "ax": 3,
+ }
+
+ loader_columns = [
+ "datasets_idx",
+ "input_ids",
+ "token_type_ids",
+ "attention_mask",
+ "start_positions",
+ "end_positions",
+ "labels",
+ ]
+
+ def __init__(
+ self,
+ model_name_or_path: str,
+ plugin: DPPluginBase,
+ task_name: str = "mrpc",
+ max_seq_length: int = 128,
+ train_batch_size: int = 32,
+ eval_batch_size: int = 32,
+ **kwargs,
+ ):
+ super().__init__()
+ self.model_name_or_path = model_name_or_path
+ self.task_name = task_name
+ self.max_seq_length = max_seq_length
+ self.train_batch_size = train_batch_size
+ self.eval_batch_size = eval_batch_size
+ self.plugin = plugin
+
+ self.text_fields = self.task_text_field_map[task_name]
+ self.num_labels = self.glue_task_num_labels[task_name]
+ self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+ self.setup()
+
+ def setup(self):
+ self.dataset = datasets.load_dataset("glue", self.task_name)
+
+ for split in self.dataset.keys():
+ self.dataset[split] = self.dataset[split].map(
+ self.convert_to_features,
+ batched=True,
+ remove_columns=["label"],
+ )
+ self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
+ self.dataset[split].set_format(type="torch", columns=self.columns)
+
+ self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
+
+ def prepare_data(self):
+ datasets.load_dataset("glue", self.task_name)
+ AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+
+ def train_dataloader(self):
+ return self.plugin.prepare_dataloader(self.dataset["train"],
+ batch_size=self.train_batch_size,
+ shuffle=True,
+ drop_last=True)
+
+ def val_dataloader(self):
+ if len(self.eval_splits) == 1:
+ return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+ elif len(self.eval_splits) > 1:
+ return [
+ self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+ for x in self.eval_splits
+ ]
+
+ def test_dataloader(self):
+ if len(self.eval_splits) == 1:
+ return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+ elif len(self.eval_splits) > 1:
+ return [
+ self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+ for x in self.eval_splits
+ ]
+
+ def convert_to_features(self, example_batch):
+
+ # Either encode single sentence or sentence pairs
+ if len(self.text_fields) > 1:
+ texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
+ else:
+ texts_or_text_pairs = example_batch[self.text_fields[0]]
+
+ # Tokenize the text/text pairs
+ features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
+ max_length=self.max_seq_length,
+ padding='max_length',
+ truncation=True)
+
+ # Rename label to labels to make it easier to pass to model forward
+ features["labels"] = example_batch["label"]
+
+ return features
diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py
new file mode 100644
index 000000000000..03e5ec91b3fe
--- /dev/null
+++ b/examples/language/gpt/hybridparallelism/finetune.py
@@ -0,0 +1,299 @@
+import argparse
+from contextlib import nullcontext
+from typing import Callable, List, Union
+
+import evaluate
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from data import GLUEDataBuilder
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+# ==============================
+# Prepare Hyperparameters
+# ==============================
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+
+output_transform_fn = lambda x: x
+criterion = lambda x: x.loss
+
+
+def move_to_cuda(batch):
+ return {k: v.cuda() for k, v in batch.items()}
+
+
+@torch.no_grad()
+def evaluate_model(
+ model: nn.Module,
+ criterion,
+ test_dataloader: Union[DataLoader, List[DataLoader]],
+ num_labels: int,
+ task_name: str,
+ eval_splits: List[str],
+ booster: Booster,
+ coordinator: DistCoordinator,
+):
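+ # load a distributed GLUE metric: each rank adds only its own batches, and compute() aggregates the results across processes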
+ metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
+ model.eval()
+
+ def evaluate_subset(dataloader: DataLoader):
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+
+ accum_loss = torch.zeros(1, device=get_current_device())
+ for batch in dataloader:
+ batch = move_to_cuda(batch)
+ labels = batch["labels"]
+ if use_pipeline:
+ pg_mesh = booster.plugin.pg_mesh
+ pp_group = booster.plugin.pp_group
+ current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
+ current_rank = dist.get_rank()
+ batch = iter([batch])
+ outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
+
+ if is_pp_last_stage:
+ logits = outputs["outputs"]["logits"]
+ val_loss = outputs["loss"]
+ accum_loss.add_(val_loss)
+
+ if num_labels > 1:
+ preds = torch.argmax(logits, axis=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
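+ # the last pipeline stage broadcasts its predictions and loss so that every rank in the pipeline group can update the metric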
+ dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)
+
+ metric.add_batch(predictions=preds, references=labels)
+ elif current_rank in current_pp_group_ranks:
+ object_list = [None, None]
+ dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)
+
+ metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
+ accum_loss.add_(object_list[1].to(get_current_device()))
+
+ else:
+ batch = move_to_cuda(batch)
+ outputs = model(**batch)
+ val_loss, logits = outputs[:2]
+ accum_loss.add_(val_loss)
+
+ if num_labels > 1:
+ preds = torch.argmax(logits, axis=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
+ metric.add_batch(predictions=preds, references=labels)
+
+ results = metric.compute()
+ dist.all_reduce(accum_loss.div_(len(dataloader)))
+ if coordinator.is_master() and results is not None:
+ results['loss'] = accum_loss.item() / coordinator.world_size
+
+ return results
+
+ if isinstance(test_dataloader, DataLoader):
+ return evaluate_subset(test_dataloader)
+ else:
+ assert len(test_dataloader) == len(eval_splits)
+ final_results = {}
+ for split, sub_loader in zip(eval_splits, test_dataloader):
+ results = evaluate_subset(sub_loader)
+ final_results.update({f'{k}_{split}': v for k, v in results.items()})
+ return final_results
+
+
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+ train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ total_step = len(train_dataloader)
+
+ model.train()
+ optimizer.zero_grad()
+ train_dataloader_iter = iter(train_dataloader)
+ with tqdm(range(total_step),
+ desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+ disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
+ # Forward pass
+ for _ in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(train_dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+ else:
+ data = next(train_dataloader_iter)
+ data = move_to_cuda(data)
+ outputs = model(**data)
+ loss = _criterion(outputs, None)
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({'loss': loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+
+def main():
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
+ parser.add_argument('-p',
+ '--plugin',
+ type=str,
+ default='torch_ddp',
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
+ help="plugin to use")
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ default="gpt2",
+ help="only gpt2 now",
+ )
+ parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
+ parser.add_argument('--use_lazy_init', action='store_true', help="whether to use the lazy init context for model instantiation")
+ args = parser.parse_args()
+
+ if args.model_type == 'gpt2':
+ model_name = "gpt2"
+ else:
+ raise RuntimeError
+ # ==============================
+ # Launch Distributed Environment
+ # ==============================
+ colossalai.launch_from_torch(config={}, seed=42)
+ coordinator = DistCoordinator()
+
+ # local_batch_size = BATCH_SIZE // coordinator.world_size
+ lr = LEARNING_RATE * coordinator.world_size
+
+ # ==============================
+ # Instantiate Plugin and Booster
+ # ==============================
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ elif args.plugin == 'hybrid_parallel':
+
+ # modify the parameters accordingly for your fine-tuning test cases
+ plugin = HybridParallelPlugin(tp_size=1,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_all_optimization=True,
+ zero_stage=1,
+ precision='fp16',
+ initial_scale=1)
+
+ booster = Booster(plugin=plugin, **booster_kwargs)
+
+ # ==============================
+ # Prepare Dataloader
+ # ==============================
+ data_builder = GLUEDataBuilder(model_name,
+ plugin,
+ args.task,
+ train_batch_size=BATCH_SIZE,
+ eval_batch_size=BATCH_SIZE)
+ train_dataloader = data_builder.train_dataloader()
+ test_dataloader = data_builder.test_dataloader()
+
+ # ====================================
+ # Prepare model, optimizer
+ # ====================================
+ # gpt2 pretrained model
+
+ cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
+
+ if model_name == "gpt2":
+ model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+ else:
+ raise RuntimeError
+
+ # optimizer
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": WEIGHT_DECAY,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
+
+ # lr scheduler
+ total_steps = len(train_dataloader) * NUM_EPOCHS
+ num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_steps,
+ )
+
+ def _criterion(outputs, inputs):
+ outputs = output_transform_fn(outputs)
+ loss = criterion(outputs)
+ return loss
+
+ # ==============================
+ # Boost with ColossalAI
+ # ==============================
+ model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+ optimizer,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
+
+ # ==============================
+ # Train model
+ # ==============================
+ for epoch in range(NUM_EPOCHS):
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
+
+ results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task,
+ data_builder.eval_splits, booster, coordinator)
+
+ if coordinator.is_master():
+ print(results)
+ if args.target_f1 is not None and 'f1' in results:
+ assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/gpt/hybridparallelism/run.sh b/examples/language/gpt/hybridparallelism/run.sh
new file mode 100644
index 000000000000..679cbbf9b1e2
--- /dev/null
+++ b/examples/language/gpt/hybridparallelism/run.sh
@@ -0,0 +1,5 @@
+# load the pretrained model over the internet (from the Hugging Face Hub)
+torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2"
+
+# load from local
+# torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2" --pretrained_path "your/path/to/pretrained_model"
diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt
index ef58bb76bfc8..1a173f228aee 100644
--- a/examples/language/gpt/requirements.txt
+++ b/examples/language/gpt/requirements.txt
@@ -1,2 +1,7 @@
transformers >= 4.23
colossalai
+evaluate
+tqdm
+scipy
+scikit-learn
+numpy
diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh
index d67c17229e71..b9e4e43a8d35 100644
--- a/examples/language/gpt/test_ci.sh
+++ b/examples/language/gpt/test_ci.sh
@@ -1,2 +1,5 @@
set -x
+pip install -r requirements.txt
+
cd gemini && bash test_ci.sh
+cd ../hybridparallelism && bash run.sh
diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py
index 64f5944a97f9..fdfc57e9ba22 100644
--- a/examples/language/gpt/titans/dataset/webtext.py
+++ b/examples/language/gpt/titans/dataset/webtext.py
@@ -6,7 +6,7 @@
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer
-from colossalai.registry import DATASETS
+from colossalai.legacy.registry import DATASETS
@DATASETS.register_module
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
index d825ae92a285..e521193a97da 100644
--- a/examples/language/gpt/titans/model/embed.py
+++ b/examples/language/gpt/titans/model/embed.py
@@ -8,11 +8,11 @@
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
-from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
-from colossalai.nn.layer.utils import divide
-from colossalai.registry import LAYERS, LOSSES, MODELS
+from colossalai.legacy.nn.layer.base_layer import ParallelLayer
+from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
+from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
+from colossalai.legacy.nn.layer.utils import divide
+from colossalai.legacy.registry import LAYERS, LOSSES, MODELS
from colossalai.utils import get_current_device
diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py
index 2edd03606b7d..72297c540da1 100644
--- a/examples/language/gpt/titans/model/gpt1d.py
+++ b/examples/language/gpt/titans/model/gpt1d.py
@@ -11,9 +11,9 @@
from colossalai import nn as col_nn
from colossalai.core import global_context as gpc
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
-from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
-from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.nn.layer.utils import ACT2FN, divide
+from colossalai.legacy.nn.layer import Linear1D_Col, Linear1D_Row
+from colossalai.legacy.nn.layer.base_layer import ParallelLayer
+from colossalai.legacy.nn.layer.utils import ACT2FN, divide
from colossalai.utils import checkpoint
from colossalai.utils.activation_checkpoint import checkpoint
diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py
index 30180285bc70..9b22d156bbcd 100644
--- a/examples/language/gpt/titans/model/pipeline_gpt1d.py
+++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py
@@ -9,8 +9,8 @@
from colossalai import nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.logging import get_dist_logger
-from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.pipeline.utils import partition_uniform
from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
index 6be0b9e8da30..b239b626c07f 100644
--- a/examples/language/gpt/titans/train_gpt.py
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -10,9 +10,9 @@
import colossalai.utils as utils
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.trainer import Trainer, hooks
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
-from colossalai.trainer import Trainer, hooks
from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
from colossalai.utils.timer import MultiTimer
from colossalai.zero.legacy.init_ctx import ZeroInitContext
diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md
index b64b5d29ecb8..83ef99b57d42 100644
--- a/examples/language/llama2/README.md
+++ b/examples/language/llama2/README.md
@@ -1,4 +1,22 @@
-# Pretraining LLaMA-2: best practices for building LLaMA-2-like base models
+# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models
+
+### LLaMA2
+
+
+
+
+- 70 billion parameter LLaMA2 model training accelerated by 195%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
+
+### LLaMA1
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
## Dataset
@@ -73,8 +91,8 @@ Make sure master node can access all nodes (including itself) by ssh without pas
Here are the details of the CLI arguments:
-- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported.
-- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
+- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1; `7b`, `13b` and `70b` are supported for LLaMA-2.
+- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It supports any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
@@ -105,7 +123,7 @@ Here we will show an example of how to run training
llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.
#### a. Running environment
-This experiment was performed on 4 computing nodes with 32 A800 GPUs in total. The nodes are
+This experiment was performed on 4 computing nodes with 32 A800 GPUs in total for LLaMA-1 65B. The nodes are
connected with RDMA and GPUs within one node are fully connected with NVLink.
#### b. Running command
@@ -131,6 +149,9 @@ Finally, run the following command to start training:
```bash
bash gemini.sh
```
+
+If you encounter an out-of-memory (OOM) error while training with `gemini.sh`, switching to `gemini_auto.sh` may help: `gemini_auto` sets an upper limit on GPU memory usage by offloading part of the model parameters and optimizer states to CPU memory. The trade-off is that `gemini_auto.sh` is somewhat slower, since more data is transferred between CPU and GPU.
+
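+In plugin terms, the difference between the two scripts is the Gemini placement policy. A rough sketch of what `gemini_auto` configures (the arguments other than `placement_policy` are illustrative):
+
+```python
+from colossalai.booster.plugin import GeminiPlugin
+
+# placement_policy='auto' lets Gemini decide how many model parameters and
+# optimizer states stay on GPU and how many are offloaded to CPU memory
+plugin = GeminiPlugin(precision='fp16', placement_policy='auto', initial_scale=2**16)
+```
+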
#### c. Results
If you run the above command successfully, you will get the following results:
`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.
@@ -174,3 +195,40 @@ If you run the above command successfully, you will get the following results:
year={2023}
}
```
+
+
+# Fine-tune Llama2
+
+We also provide an example of fine-tuning LLaMA-2 in `finetune.py`.
+
+Make sure the master node can access all nodes (including itself) via passwordless SSH.
+
+Here are the details of the CLI arguments:
+
+- Pretrained checkpoint path: `--model_path`. The path of your model checkpoint; it can be a local directory or a Hugging Face tag.
+- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
+- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
+- Task name: `--task_name`. The task to fine-tune on, which also determines which subset of the dataset is loaded. The default value is `super_natural_instructions`.
+- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
+- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
+- Learning rate: `--lr`. The default value is 3e-4.
+- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
+- Gradient checkpointing: `-g`, `--grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. We recommend enabling this option when training with a large batch size.
+- Max length: `-l`, `--max_length`. The default value is 4096.
+- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
+- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
+- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
+- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
+- Gradient clipping: `--grad_clip`. The default value is 1.0.
+- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
+- Flash attention: `-a`, `--flash_attention`. To use flash attention, you must install `flash-attn`. The default value is `False`. It accelerates training while saving memory, so we recommend always enabling it.
+
+
+```shell
+torchrun --standalone --nproc_per_node 8 finetune.py \
+ --plugin "hybrid_parallel" \
+ --dataset "yizhongw/self_instruct" \
+ --model_path "/path/llama" \
+ --task_name "super_natural_instructions" \
+ --save_dir "/path/output"
+```
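+
+To resume an interrupted run, pass the checkpoint directory written at the save interval via `-f`/`--load`. The `epoch0-step1000` path below is a hypothetical example of the `epoch{epoch}-step{step}` layout that the script produces:
+
+```shell
+torchrun --standalone --nproc_per_node 8 finetune.py \
+    --plugin "hybrid_parallel" \
+    --dataset "yizhongw/self_instruct" \
+    --model_path "/path/llama" \
+    --task_name "super_natural_instructions" \
+    --save_dir "/path/output" \
+    --load "/path/output/epoch0-step1000" \
+    --grad_checkpoint --flash_attention
+```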
diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py
new file mode 100644
index 000000000000..0efbf193c9a9
--- /dev/null
+++ b/examples/language/llama2/finetune.py
@@ -0,0 +1,295 @@
+import argparse
+import math
+import os
+import resource
+from contextlib import nullcontext
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from attn import SUPPORT_XFORMERS, replace_xformers
+from data_utils import load_json, prepare_dataloader, save_json
+from datasets import load_dataset
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+
+def get_model_numel(model: nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def format_numel_str(numel: int) -> str:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ if numel >= B:
+ return f'{numel / B:.2f} B'
+ elif numel >= M:
+ return f'{numel / M:.2f} M'
+ elif numel >= K:
+ return f'{numel / K:.2f} K'
+ else:
+ return f'{numel}'
+
+
+def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
+ texts = [sample['prompt'] + sample['completion'] for sample in batch]
+ data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
+ data = {k: v.cuda() for k, v in data.items()}
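+ # causal-LM objective: labels are a copy of input_ids; the model shifts them internally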
+ data['labels'] = data['input_ids'].clone()
+ return data
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+ tensor.div_(dist.get_world_size())
+ return tensor
+
+
+def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
+ batch_size: int, coordinator: DistCoordinator, save_dir: str):
+ save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
+ os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
+
+ booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
+ booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
+ running_states = {
+ 'epoch': epoch,
+ 'step': step,
+ 'sample_start_index': step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, 'running_states.json'))
+
+
+def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
+ load_dir: str) -> Tuple[int, int, int]:
+ booster.load_model(model, os.path.join(load_dir, 'model'))
+ booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
+ booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
+ running_states = load_json(os.path.join(load_dir, 'running_states.json'))
+ return running_states['epoch'], running_states['step'], running_states['sample_start_index']
+
+
+def _criterion(outputs, inputs):
+ return outputs.loss
+
+
+def main():
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_path', type=str, help="pretrained checkpoint path for fine-tuning")
+ parser.add_argument('-p',
+ '--plugin',
+ choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
+ default='gemini',
+ help='Choose which plugin to use')
+ parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path')
+ parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run')
+ parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
+ parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
+ parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
+ parser.add_argument('-w', '--weight_decay', type=float, default=0.1, help='Weight decay')
+ parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
+ parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
+ parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
+ parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
+ parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
+ parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
+ parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
+ parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
+ parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
+ args = parser.parse_args()
+
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == 'gemini':
+ plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
+ elif args.plugin == 'gemini_auto':
+ plugin = GeminiPlugin(precision=args.mixed_precision,
+ placement_policy='auto',
+ initial_scale=2**16,
+ max_norm=args.grad_clip)
+ elif args.plugin == 'zero2':
+ plugin = LowLevelZeroPlugin(stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip)
+ elif args.plugin == 'zero2_cpu':
+ plugin = LowLevelZeroPlugin(stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip)
+ elif args.plugin == 'hybrid_parallel':
+ # modify the parameters accordingly; the default configuration is for llama2-7b
+ plugin = HybridParallelPlugin(tp_size=4,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_jit_fused=False,
+ zero_stage=0,
+ precision='fp32',
+ initial_scale=1)
+ else:
+ raise ValueError(f'Unknown plugin {args.plugin}')
+
+ booster = Booster(plugin=plugin)
+
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+
+ # ==============================
+ # Initialize Tensorboard
+ # ==============================
+ if print_flag:
+ os.makedirs(args.tensorboard_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+
+ # ==============================
+ # Initialize Model, Optimizer and LR Scheduler
+ # ==============================
+
+ config = LlamaConfig.from_pretrained(args.model_path)
+ # use lazy init when using GeminiPlugin
+ init_ctx = LazyInitContext(
+ default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
+
+ with init_ctx:
+ model = LlamaForCausalLM(config)
+
+ # ==============================
+ # Initialize Tokenizer, Dataset and Dataloader
+ # ==============================
+ tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
+ # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
+ tokenizer.pad_token = tokenizer.unk_token
+
+ dataset = load_dataset(args.dataset, args.task_name)
+ train_ds = dataset['train']
+ dataloader = prepare_dataloader(train_ds,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=partial(tokenize_batch_for_finetune,
+ tokenizer=tokenizer,
+ max_length=args.max_length))
+
+ if args.grad_checkpoint:
+ model.gradient_checkpointing_enable()
+ if args.flash_attention:
+ assert SUPPORT_XFORMERS, 'flash attention requires xformers to be installed'
+ replace_xformers(model)
+
+ model_numel = get_model_numel(model)
+ coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
+
+ optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
+ total_step = args.num_epochs * len(dataloader)
+ lr_scheduler = CosineAnnealingWarmupLR(optimizer,
+ total_steps=total_step,
+ warmup_steps=math.ceil(total_step * 0.03),
+ eta_min=0.1 * args.lr)
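+ # temporarily switch the default dtype so tensors created during boost follow the mixed-precision setting; it is reset to float right after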
+ default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
+ optimizer,
+ dataloader=dataloader,
+ lr_scheduler=lr_scheduler)
+ torch.set_default_dtype(torch.float)
+
+ booster.load_model(model, args.model_path)
+
+ coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
+ coordinator.print_on_master(
+ f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
+
+ # load checkpoint if specified
+ start_epoch = 0
+ start_step = 0
+ sampler_start_idx = 0
+ if args.load is not None:
+ coordinator.print_on_master('Loading checkpoint')
+ start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
+ coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
+
+ num_steps_per_epoch = len(dataloader)
+
+ # if resume training, set the sampler start index to the correct value
+ dataloader.sampler.set_start_index(sampler_start_idx)
+ for epoch in range(start_epoch, args.num_epochs):
+ dataloader.sampler.set_epoch(epoch)
+ step_nums = num_steps_per_epoch - start_step
+ dataloader_iter = iter(dataloader)
+
+ with tqdm(range(step_nums),
+ desc=f'Epoch {epoch}',
+ disable=not print_flag,
+ total=num_steps_per_epoch,
+ initial=start_step) as pbar:
+ for step in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ loss = outputs["loss"]
+ else:
+ batch = next(dataloader_iter)
+ outputs = model(**batch)
+ loss = outputs[0]
+ booster.backward(loss, optimizer)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ if not use_pipeline:
+ all_reduce_mean(loss)
+ if print_flag:
+ pbar.set_postfix({'loss': loss.item()})
+ writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
+
+ if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
+ coordinator.print_on_master(f'Saving checkpoint')
+ save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
+ args.save_dir)
+ coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
+ # later epochs start from the beginning of the dataset, so reset the sampler start index and the start step
+ dataloader.sampler.set_start_index(0)
+ start_step = 0
+
+ coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
index b72a3019692e..0eeac4035401 100644
--- a/examples/language/llama2/pretrain.py
+++ b/examples/language/llama2/pretrain.py
@@ -21,7 +21,7 @@
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -65,9 +65,10 @@ def format_numel_str(numel: int) -> str:
return f'{numel}'
-def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
+def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
texts = [sample['text'] for sample in batch]
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
+ data = {k: v.cuda() for k, v in data.items()}
data['labels'] = data['input_ids'].clone()
return data
@@ -104,6 +105,10 @@ def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler:
return running_states['epoch'], running_states['step'], running_states['sample_start_index']
+def _criterion(outputs, inputs):
+ return outputs.loss
+
+
def main():
# ==============================
# Parse Arguments
@@ -112,7 +117,7 @@ def main():
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
parser.add_argument('-p',
'--plugin',
- choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'],
+ choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
default='gemini',
help='Choose which plugin to use')
parser.add_argument('-d',
@@ -142,13 +147,6 @@ def main():
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
- # ==============================
- # Initialize Tensorboard
- # ==============================
- if coordinator.is_master():
- os.makedirs(args.tensorboard_dir, exist_ok=True)
- writer = SummaryWriter(args.tensorboard_dir)
-
# ==============================
# Initialize Booster
# ==============================
@@ -170,11 +168,32 @@ def main():
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip)
+ elif args.plugin == 'hybrid_parallel':
+ # modify the parameters accordingly; the default configuration is for llama2-7b
+ plugin = HybridParallelPlugin(tp_size=4,
+ pp_size=2,
+ num_microbatches=None,
+ microbatch_size=1,
+ enable_jit_fused=False,
+ zero_stage=0,
+ precision='fp32',
+ initial_scale=1)
else:
raise ValueError(f'Unknown plugin {args.plugin}')
booster = Booster(plugin=plugin)
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+
+ # ==============================
+ # Initialize Tensorboard
+ # ==============================
+ if print_flag:
+ os.makedirs(args.tensorboard_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+
# ==============================
# Initialize Tokenizer, Dataset and Dataloader
# ==============================
@@ -188,12 +207,15 @@ def main():
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
- collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length))
+ collate_fn=partial(tokenize_batch_for_pretrain,
+ tokenizer=tokenizer,
+ max_length=args.max_length))
# ==============================
# Initialize Model, Optimizer and LR Scheduler
# ==============================
config = MODEL_CONFIGS[args.config]
+ # use lazy init when using GeminiPlugin
init_ctx = LazyInitContext(
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
@@ -236,27 +258,42 @@ def main():
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
num_steps_per_epoch = len(dataloader)
+
# if resume training, set the sampler start index to the correct value
dataloader.sampler.set_start_index(sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch)
- with tqdm(enumerate(dataloader),
+ step_nums = num_steps_per_epoch - start_step
+ dataloader_iter = iter(dataloader)
+
+ with tqdm(range(step_nums),
desc=f'Epoch {epoch}',
- disable=not coordinator.is_master(),
+ disable=not print_flag,
total=num_steps_per_epoch,
initial=start_step) as pbar:
- for step, batch in pbar:
- batch = {k: v.cuda() for k, v in batch.items()}
- outputs = model(**batch)
- loss = outputs[0]
- booster.backward(loss, optimizer)
+ for step in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(dataloader_iter,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ loss = outputs["loss"]
+ else:
+ batch = next(dataloader_iter)
+ outputs = model(**batch)
+ loss = outputs[0]
+ booster.backward(loss, optimizer)
+
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
- all_reduce_mean(loss)
- pbar.set_postfix({'loss': loss.item()})
- if coordinator.is_master():
+ if not use_pipeline:
+ all_reduce_mean(loss)
+ if print_flag:
+ pbar.set_postfix({'loss': loss.item()})
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt
index 3ddf21ffe534..6b475682dad0 100644
--- a/examples/language/llama2/requirements.txt
+++ b/examples/language/llama2/requirements.txt
@@ -1,4 +1,4 @@
-colossalai>=0.3.0
+colossalai>=0.3.2
datasets
numpy
torch>=1.12.0,<=2.0.0
diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md
index 37e1ff4d9008..af1e794374ed 100644
--- a/examples/language/opt/README.md
+++ b/examples/language/opt/README.md
@@ -23,9 +23,9 @@ The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI)
## Our Modifications
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
-the tokenization).
+the tokenization).
-We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
+We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, HybridParallelPlugin and GeminiPlugin.
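+
+In plugin terms, switching the training strategy only changes how the plugin is constructed; the boost call stays the same. A minimal sketch, assuming `model`, `optimizer`, `dataloader` and `lr_scheduler` are prepared as in `opt_train_demo.py`:
+
+```python
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin
+
+# Gemini with fully offloaded optimizer states, as in the demo script
+plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
+booster = Booster(plugin=plugin)
+model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+    model=model, optimizer=optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler)
+```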
## Run Demo
@@ -48,6 +48,3 @@ You can run benchmark for OPT model by running the following script:
bash run_benchmark.sh
```
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing.
-
-
-
diff --git a/examples/language/opt/args.py b/examples/language/opt/args.py
index 16730be7ebea..77fa12bc8a0c 100644
--- a/examples/language/opt/args.py
+++ b/examples/language/opt/args.py
@@ -4,117 +4,65 @@
def parse_demo_args():
parser = get_default_parser()
- parser.add_argument(
- "--model_name_or_path",
- type=str,
- default="facebook/opt-350m",
- help="Path to pretrained model or model identifier from huggingface.co/models."
- )
- parser.add_argument(
- "--output_path",
- type=str,
- default="./output_model.bin",
- help="The path of your saved model after finetuning."
- )
+ parser.add_argument("--model_name_or_path",
+ type=str,
+ default="facebook/opt-350m",
+ help="Path to pretrained model or model identifier from huggingface.co/models.")
+ parser.add_argument("--output_path",
+ type=str,
+ default="./output_model.bin",
+ help="The path of your saved model after finetuning.")
parser.add_argument(
"--plugin",
type=str,
default="gemini",
- help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
- )
- parser.add_argument(
- "--num_epoch",
- type=int,
- default=10,
- help="Number of epochs."
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=32,
- help="Batch size (per dp group) for the training dataloader."
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=5e-5,
- help="Initial learning rate (after the potential warmup period) to use."
- )
- parser.add_argument(
- "--warmup_ratio",
- type=float,
- default=0.1,
- help="Ratio of warmup steps against total training steps."
- )
- parser.add_argument(
- "--weight_decay",
- type=float,
- default=0.01,
- help="Weight decay to use."
- )
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="A seed for reproducible training."
- )
+ help=
+ "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
+ )
+ parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.")
+ parser.add_argument("--batch_size",
+ type=int,
+ default=32,
+ help="Batch size (per dp group) for the training dataloader.")
+ parser.add_argument("--learning_rate",
+ type=float,
+ default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use.")
+ parser.add_argument("--warmup_ratio",
+ type=float,
+ default=0.1,
+ help="Ratio of warmup steps against total training steps.")
+ parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
args = parser.parse_args()
return args
-
def parse_benchmark_args():
parser = get_default_parser()
- parser.add_argument(
- "--model_name_or_path",
- type=str,
- default="facebook/opt-125m",
- help="Path to pretrained model or model identifier from huggingface.co/models."
- )
+ parser.add_argument("--model_name_or_path",
+ type=str,
+ default="facebook/opt-125m",
+ help="Path to pretrained model or model identifier from huggingface.co/models.")
parser.add_argument(
"--plugin",
type=str,
default="gemini",
- help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=32,
- help="Batch size (per dp group) for the training dataloader."
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=5e-5,
- help="Initial learning rate (after the potential warmup period) to use."
- )
- parser.add_argument(
- "--weight_decay",
- type=float,
- default=0.0,
- help="Weight decay to use."
- )
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=20,
- help="Total number of training steps to perform."
- )
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="A seed for reproducible training."
- )
- parser.add_argument(
- "--mem_cap",
- type=int,
- default=0,
- help="Limit on the usage of space for each GPU (in GB)."
- )
+ help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'.")
+ parser.add_argument("--batch_size",
+ type=int,
+ default=32,
+ help="Batch size (per dp group) for the training dataloader.")
+ parser.add_argument("--learning_rate",
+ type=float,
+ default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use.")
+ parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+ parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.")
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).")
args = parser.parse_args()
- return args
\ No newline at end of file
+ return args
diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py
index 80063407ecd5..7d6bdfb9f31c 100644
--- a/examples/language/opt/opt_train_demo.py
+++ b/examples/language/opt/opt_train_demo.py
@@ -11,7 +11,8 @@
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -19,35 +20,54 @@
require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
+output_transform_fn = lambda x: x
+criterion = lambda x: x.loss
+
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
-def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
+def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator):
torch.cuda.synchronize()
- model.train()
-
- with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
- for batch in pbar:
-
- # Forward
- optimizer.zero_grad()
- batch = move_to_cuda(batch, torch.cuda.current_device())
+ use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ total_step = len(dataloader)
- outputs = model(use_cache=False, **batch)
- loss = outputs['loss']
+ model.train()
+ optimizer.zero_grad()
+ dataloader = iter(dataloader)
+ with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}]',
+ disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
+
+ # Forward pass
+ for _ in pbar:
+ if use_pipeline:
+ outputs = booster.execute_pipeline(dataloader,
+ model,
+ _criterion,
+ optimizer,
+ return_loss=True,
+ return_outputs=True)
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs['loss']
+ pbar.set_postfix({'loss': loss.item()})
+ else:
+ data = next(dataloader)
+ data = move_to_cuda(data, torch.cuda.current_device())
+ outputs = model(**data)
+ loss = _criterion(outputs, None)
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({'loss': loss.item()})
- # Backward
- booster.backward(loss, optimizer)
optimizer.step()
+ optimizer.zero_grad()
lr_scheduler.step()
- # Print batch loss
- pbar.set_postfix({'loss': loss.item()})
-
def main():
@@ -86,6 +106,16 @@ def main():
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ elif args.plugin == 'hybrid_parallel':
+ # modify the parameters accordingly for your fine-tuning test cases
+ plugin = HybridParallelPlugin(tp_size=2,
+ pp_size=2,
+ num_microbatches=2,
+ enable_all_optimization=True,
+ zero_stage=0,
+ precision='fp16',
+ initial_scale=1)
+
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Prepare tokenizer and dataloader
@@ -107,21 +137,28 @@ def main():
num_warmup_steps=num_warmup_steps,
num_training_steps=len(dataloader) * args.num_epoch)
+ # Define criterion
+ def _criterion(outputs, inputs):
+ outputs = output_transform_fn(outputs)
+ loss = criterion(outputs)
+ return loss
+
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
- model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
- optimizer=optimizer,
- dataloader=dataloader,
- lr_scheduler=lr_scheduler)
+ model, optimizer, _criterion, dataloader, lr_scheduler = booster.boost(model=model,
+ optimizer=optimizer,
+ dataloader=dataloader,
+ criterion=_criterion,
+ lr_scheduler=lr_scheduler)
# Start finetuning
logger.info(f"Start finetuning", ranks=[0])
for epoch in range(args.num_epoch):
- train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator)
+ train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator)
# Finish training and evaluate
logger.info(f"Finish finetuning", ranks=[0])
- booster.save_model(model, args.output_path)
+ booster.save_model(model, args.output_path, shard=True)
logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt
index 4422216e6a1c..45bfbc37195f 100644
--- a/examples/language/opt/requirements.txt
+++ b/examples/language/opt/requirements.txt
@@ -1,4 +1,4 @@
-colossalai >= 0.1.12
+colossalai >= 0.3.2
torch >= 1.8.1
datasets >= 1.8.0
-transformers >= 4.20.0
\ No newline at end of file
+transformers >= 4.30.2
diff --git a/examples/language/opt/run_demo.sh b/examples/language/opt/run_demo.sh
index 0c9759c34039..07b429cecf1e 100644
--- a/examples/language/opt/run_demo.sh
+++ b/examples/language/opt/run_demo.sh
@@ -9,7 +9,7 @@ OUTPUT_PATH="./output_model.bin"
# plugin(training strategy)
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel"
-PLUGIN="gemini"
+PLUGIN="hybrid_parallel"
# number of gpus to use
GPUNUM=4
diff --git a/examples/tutorial/hybrid_parallel/test_ci.sh b/examples/tutorial/hybrid_parallel/test_ci.sh
index e0dbef354e2d..24cee1da3de4 100644
--- a/examples/tutorial/hybrid_parallel/test_ci.sh
+++ b/examples/tutorial/hybrid_parallel/test_ci.sh
@@ -1,5 +1,7 @@
#!/bin/bash
set -euxo pipefail
-pip install -r requirements.txt
-colossalai run --nproc_per_node 4 train.py --config config.py
+echo "legacy example"
+
+# pip install -r requirements.txt
+# colossalai run --nproc_per_node 4 train.py --config config.py
diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py
index 4953d5350f31..12cdec902400 100644
--- a/examples/tutorial/hybrid_parallel/train.py
+++ b/examples/tutorial/hybrid_parallel/train.py
@@ -7,8 +7,8 @@
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.nn import CrossEntropyLoss
from colossalai.logging import get_dist_logger
-from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.utils import is_using_pp
diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py
index b4febcd822e1..9a25dc453c24 100644
--- a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py
+++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py
@@ -3,17 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-
# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.
-from functools import lru_cache
import os
import shutil
import struct
+from functools import lru_cache
from itertools import accumulate
import numpy as np
@@ -88,16 +87,7 @@ def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
-dtypes = {
- 1: np.uint8,
- 2: np.int8,
- 3: np.int16,
- 4: np.int32,
- 5: np.int64,
- 6: np.float,
- 7: np.double,
- 8: np.uint16
-}
+dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: float, 7: np.double, 8: np.uint16}
def code(dtype):
@@ -136,10 +126,8 @@ def __init__(self, path):
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
- assert magic == self._HDR_MAGIC, (
- 'Index file doesn\'t match expected format. '
- 'Make sure that --dataset-impl is configured properly.'
- )
+ assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
+ 'Make sure that --dataset-impl is configured properly.')
version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
new file mode 100644
--- /dev/null
+++ b/tests/test_infer/test_bloom_infer.py
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+
+
+@parameterize('test_config', [{
+ 'tp_size': TP_SIZE,
+}])
+def run(test_config):
+
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm')
+ for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
+ orig_model = model_fn()
+ orig_model = orig_model.half()
+ data = data_gen_fn()
+
+        shard_config = ShardConfig(enable_tensor_parallelism=test_config['tp_size'] > 1, inference_only=True)
+ infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+
+ generate_kwargs = dict(do_sample=False)
+ outputs = infer_engine.generate(data, **generate_kwargs)
+
+ assert outputs is not None
+
+
+def check_bloom(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bloom_infer():
+ spawn(check_bloom, TP_SIZE)
+
+
+if __name__ == '__main__':
+ test_bloom_infer()
diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py
new file mode 100644
index 000000000000..cc3cdd2b501b
--- /dev/null
+++ b/tests/test_infer/test_infer_engine.py
@@ -0,0 +1,94 @@
+from itertools import accumulate
+
+import pytest
+import torch
+import torch.nn as nn
+from packaging import version
+from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers.tokenization_utils_base import BatchEncoding
+
+import colossalai
+from colossalai.inference.tensor_parallel import TPInferEngine
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+TP_SIZE = 2
+MAX_BATCH_SIZE = 4
+MAX_INPUT_LEN = 16
+MAX_OUTPUT_LEN = 8
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+
+
+@parameterize('test_config', [{
+ 'tp_size': TP_SIZE,
+}])
+def run(test_config):
+ model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
+ model = BloomForCausalLM(model_config)
+ model = model.half()
+ model.to(torch.cuda.current_device())
+
+ # 1. check TPInferEngine init and model optimization
+    shard_config = ShardConfig(enable_tensor_parallelism=test_config['tp_size'] > 1, inference_only=True)
+ infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+
+ assert infer_engine.cache_manager is not None
+ assert infer_engine.tp_size == TP_SIZE
+ assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE
+
+ # 2. check data preparation
+ input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
+ [80540, 15473, 3331, 11970], [80540, 15473]]
+ batch_size = len(input_ids_list)
+ max_seq_len = max(len(li) for li in input_ids_list)
+ attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
+ for i, li in enumerate(input_ids_list):
+ attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
+ data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
+ inputs_batch_encoding = BatchEncoding(data=data)
+ seq_lengths = [len(li) for li in input_ids_list]
+ start_loc = list(accumulate([0] + seq_lengths[:-1]))
+ seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
+ start_loc = torch.tensor(start_loc, dtype=torch.int32)
+ # input token id list as inputs
+ batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
+ # BatchEncoding as inputs
+ batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)
+
+ assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
+ assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)
+
+ # The following tests are discarded for now, and will be reused after all features are added
+ # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
+ # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
+ # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
+ # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
+
+ # 3. check optimized model generate
+ input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
+ generate_kwargs = dict(do_sample=False)
+ infer_engine.generate(input_ids, **generate_kwargs)
+
+ torch.cuda.empty_cache()
+
+
+def check_engine(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_engine():
+ spawn(check_engine, TP_SIZE)
+
+
+if __name__ == '__main__':
+ test_engine()
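The `start_loc` bookkeeping asserted above is just a running offset over packed variable-length sequences. A tiny, self-contained illustration of the same computation:

```python
from itertools import accumulate

# Same lengths as the test's input_ids_list: 7, 4, 4, 2 tokens per sequence.
seq_lengths = [7, 4, 4, 2]
start_loc = list(accumulate([0] + seq_lengths[:-1]))

# Sequence i occupies the flat range [start_loc[i], start_loc[i] + seq_lengths[i]).
assert start_loc == [0, 7, 11, 15]
```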
diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py
new file mode 100644
index 000000000000..f57c6956f817
--- /dev/null
+++ b/tests/test_infer/test_kvcache_manager.py
@@ -0,0 +1,61 @@
+import os
+from packaging import version
+import pytest
+import torch
+
+from colossalai.inference.tensor_parallel import MemoryManager
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+BATCH_SIZE = 4
+INPUT_LEN = 16
+OUTPUT_LEN = 8
+LAYER_NUM = 4
+HEAD_NUM = 32
+HEAD_DIM = 128
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+
+def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
+ os.environ['RANK'] = str(rank)
+ os.environ['LOCAL_RANK'] = str(rank)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = str(port)
+ disable_existing_loggers()
+
+ size = batch_size * (input_len + output_len)
+ kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
+ key_buffers = kvcache_manager.key_buffer
+ value_buffers = kvcache_manager.value_buffer
+ assert len(key_buffers) == len(value_buffers) == layer_num
+ assert key_buffers[0].shape == value_buffers[0].shape
+ # required size exceeds the maximum allocated size
+ invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
+ assert invalid_locs is None
+ # for prefill stage, allocation via alloc and alloc_contiguous should be the same
+ total_token_prefill = batch_size * input_len
+ prefill_locs = kvcache_manager.alloc(total_token_prefill)
+ kvcache_manager.free_all()
+ prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
+ assert torch.equal(prefill_locs, prefill_locs_contiguous)
+ assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
+ kvcache_manager.alloc_contiguous(batch_size)
+ assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_cache_manager_dist():
+ spawn(create_cache_manager,
+ 4,
+ batch_size=BATCH_SIZE,
+ input_len=INPUT_LEN,
+ output_len=OUTPUT_LEN,
+ layer_num=LAYER_NUM,
+ head_num=HEAD_NUM,
+ head_dim=HEAD_DIM)
+
+
+if __name__ == '__main__':
+ test_cache_manager_dist()
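The allocation semantics this test pins down: `alloc` may hand back scattered slots, `alloc_contiguous` must find one unbroken run (and returns `None` when it cannot), and `mem_state` tracks which slots remain free. A toy bitmap sketch of those semantics, not the `MemoryManager` implementation itself:

```python
import torch

mem_state = torch.ones(16, dtype=torch.bool)    # True = slot is free

def alloc_contiguous(n: int):
    # Scan for a run of n consecutive free slots; None on failure,
    # mirroring the invalid_locs check in the test above.
    for start in range(mem_state.numel() - n + 1):
        if bool(mem_state[start:start + n].all()):
            mem_state[start:start + n] = False
            return torch.arange(start, start + n)
    return None

locs = alloc_contiguous(8)
assert locs is not None and int(mem_state.sum()) == 8
assert alloc_contiguous(16) is None             # only 8 free slots remain
```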
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
new file mode 100644
index 000000000000..aa8874ea4cb0
--- /dev/null
+++ b/tests/test_infer/test_llama_infer.py
@@ -0,0 +1,84 @@
+import os
+import warnings
+
+import pytest
+import torch
+from packaging import version
+
+import colossalai
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+TPSIZE = 2
+BATCH_SIZE = 8
+MAX_INPUT_LEN = 12
+MAX_OUTPUT_LEN = 100
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+
+
+def init_to_get_rotary(self, base=10000):
+ self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+ if not hasattr(self.config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+ if hasattr(self.config, "max_sequence_length"):
+ max_seq_len = self.config.max_sequence_length
+ elif hasattr(self.config, "max_position_embeddings"):
+ max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+ inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
+ self.config.head_dim_))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+ self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ return
+
+
+@parameterize('test_config', [{
+ 'tp_size': TPSIZE,
+}])
+def run_llama_test(test_config):
+
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm')
+ for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
+ orig_model = model_fn()
+ init_to_get_rotary(orig_model.model, base=10000)
+ orig_model = orig_model.half()
+ data = data_gen_fn()
+
+        shard_config = ShardConfig(enable_tensor_parallelism=test_config['tp_size'] > 1, inference_only=True)
+ infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+
+ generate_kwargs = dict(do_sample=False)
+ outputs = infer_engine.generate(data, **generate_kwargs)
+
+ assert outputs is not None
+
+
+def check_llama(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_llama_test()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama():
+ spawn(check_llama, TPSIZE)
+
+
+if __name__ == "__main__":
+ test_llama()
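`init_to_get_rotary` above precomputes the rotary cos/sin caches with frequencies theta_i = base^(-2i/d). A CPU-sized miniature of that cache construction, for intuition only:

```python
import torch

head_dim, max_seq_len, base = 8, 4, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                # [max_seq_len, head_dim // 2]

cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)
# Position 0 is rotated by angle 0, so its cos row is all ones.
assert torch.allclose(cos_cached[0], torch.ones(head_dim // 2))
```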
diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py
new file mode 100644
index 000000000000..cb12faf6276c
--- /dev/null
+++ b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+import os
+import pytest
+import numpy as np
+from packaging import version
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+try:
+ from vllm import layernorm_ops
+ rms_norm = layernorm_ops.rms_norm
+ HAS_VLLM_KERNERL = True
+except ImportError:
+    print("please install vllm kernels to use rmsnorm")
+ print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+ HAS_VLLM_KERNERL = False
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
+ x = hidden_states
+ out = torch.empty_like(x)
+ rms_norm(
+ out,
+ x,
+ weight,
+ variance_epsilon,
+ )
+ return out
+
+@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
+def test_rmsnorm():
+ data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
+ hg_rms = LlamaRMSNorm(64)
+ hg_rms = hg_rms.half().cuda()
+ out_torch = hg_rms(data)
+ out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)
+
+ check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
+    assert check, "cuda rmsnorm forward does not match torch rmsnorm forward"
+
+if __name__ == "__main__":
+ test_rmsnorm()
\ No newline at end of file
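The reference module above implements RMSNorm, y = w * x / sqrt(mean(x^2) + eps). A quick eager check of the defining property (rows come out with unit root-mean-square when w = 1), runnable on CPU:

```python
import torch

x = torch.randn(4, 64, dtype=torch.float32)
eps = 1e-6
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# After normalization each row has (approximately) unit root-mean-square.
rms = y.pow(2).mean(-1).sqrt()
assert torch.allclose(rms, torch.ones(4), atol=1e-3)
```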
diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py
new file mode 100644
index 000000000000..2a85566c65c6
--- /dev/null
+++ b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+import pytest
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half
+
+try:
+ from vllm import pos_encoding_ops
+ rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
+ HAS_VLLM_KERNERL = True
+except ImportError:
+    print("falling back to huggingface's original rotary_embedding_neox")
+ print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+ HAS_VLLM_KERNERL = False
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+ x1 = x[..., :x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class RefRotaryEmbeddingNeox(nn.Module):
+ """Reference implementation of the GPT-NeoX style rotary embedding."""
+
+ def __init__(
+ self,
+ dim: int,
+ max_position_embeddings: int = 2048,
+ base: int = 10000,
+ ) -> None:
+ super().__init__()
+ self.rotary_dim = dim
+ self.max_position_embeddings = max_position_embeddings
+
+ # Create cos and sin embeddings.
+ inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
+ t = torch.arange(max_position_embeddings).float()
+ freqs = torch.einsum("i,j->ij", t, inv_freq.float())
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos().to(dtype=inv_freq.dtype)
+ sin = emb.sin().to(dtype=inv_freq.dtype)
+ self.register_buffer("cos_cached", cos, persistent=False)
+ self.register_buffer("sin_cached", sin, persistent=False)
+
+ def forward(
+ self,
+ positions: torch.Tensor, # [num_tokens]
+ query: torch.Tensor, # [num_tokens, num_heads, head_size]
+ key: torch.Tensor, # [num_tokens, num_heads, head_size]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ query_rot = query[..., :self.rotary_dim]
+ query_pass = query[..., self.rotary_dim:]
+ key_rot = key[..., :self.rotary_dim]
+ key_pass = key[..., self.rotary_dim:]
+
+ query_rot = query_rot.transpose(0, 1)
+ key_rot = key_rot.transpose(0, 1)
+ cos = F.embedding(positions, self.cos_cached)
+ sin = F.embedding(positions, self.sin_cached)
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+ query_rot = query_rot.transpose(0, 1).contiguous()
+ key_rot = key_rot.transpose(0, 1).contiguous()
+
+ query = torch.cat((query_rot, query_pass), dim=-1)
+ key = torch.cat((key_rot, key_pass), dim=-1)
+
+        # Output query/key shape: [num_tokens, num_heads, head_size]
+ return query, key
+
+def run_rotary_embedding_neox(
+ num_tokens: int,
+ num_heads: int,
+ head_size: int,
+ max_position: int,
+ rotary_dim: int,
+ dtype: torch.dtype,
+ base: int = 10000,
+) -> None:
+ positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+ query = torch.randn(num_tokens,
+ num_heads * head_size,
+ dtype=dtype,
+ device='cuda')
+ key = torch.randn(num_tokens,
+ num_heads * head_size,
+ dtype=dtype,
+ device='cuda')
+
+ # Create the rotary embedding.
+ inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+ t = torch.arange(max_position).float()
+ freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cos_sin_cache = torch.cat((cos, sin), dim=-1)
+ cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+
+ # Run the kernel. The kernel is in-place, so we need to clone the inputs.
+ out_query = query.clone()
+ out_key = key.clone()
+ rotary_embedding_neox(
+ positions,
+ out_query,
+ out_key,
+ head_size,
+ cos_sin_cache,
+ )
+
+ # Run the reference implementation.
+ ref_rotary_embedding = RefRotaryEmbeddingNeox(
+ dim=rotary_dim,
+ max_position_embeddings=max_position,
+ base=base,
+ ).to(dtype=dtype, device='cuda')
+ ref_query, ref_key = ref_rotary_embedding(
+ positions,
+ query.view(num_tokens, num_heads, head_size),
+ key.view(num_tokens, num_heads, head_size),
+ )
+ ref_query = ref_query.view(num_tokens, num_heads * head_size)
+ ref_key = ref_key.view(num_tokens, num_heads * head_size)
+
+ # Compare the results.
+ assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
+ assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
+
+@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
+def test_rotary_embedding():
+ run_rotary_embedding_neox(
+ num_tokens=1024,
+ num_heads=8,
+ head_size=64,
+ max_position=8192,
+ rotary_dim=64,
+ dtype=torch.float16,
+ )
+
+if __name__ == "__main__":
+ test_rotary_embedding()
\ No newline at end of file
diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py
new file mode 100644
index 000000000000..b081b32b9ad3
--- /dev/null
+++ b/tests/test_infer_ops/triton/kernel_utils.py
@@ -0,0 +1,28 @@
+import math
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
+ '''
+    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
+ '''
+ xq = xq.view(bs, seqlen, num_head, head_dim)
+ xk = xk.view(bs, seqlen, num_head, head_dim)
+ xv = xv.view(bs, seqlen, num_head, head_dim)
+ mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
+ mask[mask == 0.] = -100000000.0
+ mask = mask.repeat(bs, num_head, 1, 1)
+ keys = xk
+ values = xv
+ xq = xq.transpose(1, 2)
+ keys = keys.transpose(1, 2)
+ values = values.transpose(1, 2)
+ sm_scale = 1 / math.sqrt(head_dim)
+ scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
+ scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)
+
+ output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
+ return output
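`torch_context_attention` is a padded, masked reference for causal (prefill) attention. The same semantics, minus the per-batch offset bookkeeping, are available through PyTorch's built-in fused path; a small CPU sketch (requires torch >= 2.0 for `scaled_dot_product_attention`):

```python
import torch
import torch.nn.functional as F

bs, seqlen, num_head, head_dim = 2, 4, 2, 8
q = torch.randn(bs, num_head, seqlen, head_dim)
k = torch.randn(bs, num_head, seqlen, head_dim)
v = torch.randn(bs, num_head, seqlen, head_dim)

# is_causal=True applies the same lower-triangular mask the reference builds by hand.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert out.shape == (bs, num_head, seqlen, head_dim)
```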
diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py
new file mode 100644
index 000000000000..344ad078e2e2
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py
@@ -0,0 +1,54 @@
+import math
+
+import pytest
+import torch
+from packaging import version
+from torch import nn
+from torch.nn import functional as F
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton import bloom_context_attn_fwd
+ from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test_bloom_context_attention():
+ bs = 4
+ head_num = 8
+ seq_len = 1024
+ head_dim = 64
+
+ query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+ k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+ v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+
+ max_input_len = seq_len
+ b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+ b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+
+ for i in range(bs):
+ b_start[i] = i * seq_len
+ b_len[i] = seq_len
+
+ o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+ alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
+ bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)
+
+ torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
+
+    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
+                          atol=1e-2), "outputs from triton and torch do not match"
+
+
+if __name__ == "__main__":
+ test_bloom_context_attention()
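Note the test above passes an all-zero `alibi`, which turns the bias off. For reference, BLOOM's actual ALiBi slopes form a geometric sequence over the heads, 2^(-8i/n) for a power-of-two head count n; a sketch of that computation:

```python
import torch

def alibi_slopes(head_num: int) -> torch.Tensor:
    # Assumes head_num is a power of two, as in BLOOM's standard configuration.
    start = 2.0 ** (-8.0 / head_num)
    return torch.tensor([start ** (i + 1) for i in range(head_num)], dtype=torch.float32)

# For 8 heads: 0.5, 0.25, 0.125, ..., 0.00390625
print(alibi_slopes(8))
```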
diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py
new file mode 100644
index 000000000000..c656f81d2790
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+from packaging import version
+from torch import nn
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test_kv_cache_copy_op():
+
+ B_NTX = 32 * 2048
+ head_num = 8
+ head_dim = 64
+
+ cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
+ dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)
+
+ dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
+
+ copy_kv_cache_to_dest(cache, dest_index, dest_data)
+
+    assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3,
+                          atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch do not match"
+
+
+if __name__ == "__main__":
+ test_kv_cache_copy_op()
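The kernel checked above performs a scatter copy: row `i` of `cache` lands at position `dest_index[i]` of the destination buffer. The eager PyTorch equivalent is plain advanced indexing:

```python
import torch

cache = torch.randn(16, 8, 64)
dest_index = torch.randperm(16)
dest_data = torch.zeros(16, 8, 64)

dest_data[dest_index] = cache          # same effect as copy_kv_cache_to_dest
assert torch.allclose(dest_data[dest_index], cache)
```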
diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py
new file mode 100644
index 000000000000..94cd704ffeba
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_layernorm_triton.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+from packaging import version
+
+from colossalai.kernel.triton import layer_norm
+from colossalai.testing.utils import parameterize
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+@parameterize('M', [2, 4, 8, 16])
+@parameterize('N', [64, 128])
+def test_layer_norm(M, N):
+ dtype = torch.float16
+ eps = 1e-5
+ x_shape = (M, N)
+ w_shape = (x_shape[-1],)
+ weight = torch.rand(w_shape, dtype=dtype, device='cuda')
+ bias = torch.rand(w_shape, dtype=dtype, device='cuda')
+ x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
+
+ y_triton = layer_norm(x, weight, bias, eps)
+ y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
+
+ assert y_triton.shape == y_torch.shape
+ assert y_triton.dtype == y_torch.dtype
+ print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
+ assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
+
+
+if __name__ == "__main__":
+ test_layer_norm()
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
new file mode 100644
index 000000000000..4ea6095d4109
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -0,0 +1,53 @@
+import math
+
+import pytest
+import torch
+from packaging import version
+from torch import nn
+from torch.nn import functional as F
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton import llama_context_attn_fwd
+ from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test_llama_context_attention():
+ bs = 4
+ head_num = 8
+ seq_len = 1024
+ head_dim = 64
+
+ query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+ k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+ v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+
+ max_input_len = seq_len
+ b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+ b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+
+ for i in range(bs):
+ b_start[i] = i * seq_len
+ b_len[i] = seq_len
+
+ o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+ llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
+
+ torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
+
+    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
+                          atol=1e-3), "outputs from triton and torch do not match"
+
+
+if __name__ == "__main__":
+ test_llama_context_attention()
diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py
new file mode 100644
index 000000000000..d5ecdf684538
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_rotary_embedding.py
@@ -0,0 +1,56 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+
+import time
+
+import pytest
+import torch
+from packaging import version
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+def torch_rotary_emb(x, cos, sin):
+ seq_len, h, dim = x.shape
+ x0 = x[:, :, 0:dim // 2]
+ x1 = x[:, :, dim // 2:dim]
+ cos = cos.view((seq_len, 1, dim // 2))
+ sin = sin.view((seq_len, 1, dim // 2))
+ o0 = x0 * cos - x1 * sin
+ o1 = x0 * sin + x1 * cos
+ return torch.cat((o0, o1), dim=-1)
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test_rotary_emb():
+ SEQ_LEN = 1
+ HEAD_NUM = 32
+ HEAD_DIM = 128
+ dtype = torch.half
+ # create data
+ x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
+ x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
+ cos_shape = (SEQ_LEN, HEAD_DIM // 2)
+ cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
+ sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
+ # forward pass
+ y_torch = torch_rotary_emb(x, cos, sin)
+ rotary_embedding_fwd(x, cos, sin)
+ y_triton = x
+ # compare
+ assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)
+
+
+if __name__ == "__main__":
+ test_rotary_emb()
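`torch_rotary_emb` above uses the half-split rotary convention, which is exactly complex multiplication of (x0 + i*x1) by (cos + i*sin). A tiny CPU check of that identity:

```python
import torch

seq_len, heads, dim = 3, 2, 8
x = torch.randn(seq_len, heads, dim)
cos = torch.randn(seq_len, dim // 2)
sin = torch.randn(seq_len, dim // 2)

x0, x1 = x[..., :dim // 2], x[..., dim // 2:]
c, s = cos.unsqueeze(1), sin.unsqueeze(1)
ref = torch.cat((x0 * c - x1 * s, x0 * s + x1 * c), dim=-1)   # same math as torch_rotary_emb

z = torch.complex(x0, x1) * torch.complex(c.expand_as(x0), s.expand_as(x0))
assert torch.allclose(ref, torch.cat((z.real, z.imag), dim=-1), atol=1e-5)
```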
diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py
similarity index 91%
rename from tests/test_kernels/test_self_attention.py
rename to tests/test_infer_ops/triton/test_self_attention_nonfusion.py
index b316404a58db..9692737a05a0 100644
--- a/tests/test_kernels/test_self_attention.py
+++ b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py
@@ -4,12 +4,11 @@
from torch import nn
import torch.nn.functional as F
-from colossalai.kernel.triton.ops import self_attention_compute_using_triton
-from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
-
try:
import triton
import triton.language as tl
+ from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
+ from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -17,7 +16,7 @@
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_qkv_matmul():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
scale = 1.2
@@ -106,7 +105,7 @@ def self_attention_compute_using_torch(qkv,
return res.view(batches, -1, d_model), score_output, softmax_output
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_self_atttention_test():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
diff --git a/tests/test_kernels/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py
similarity index 70%
rename from tests/test_kernels/test_softmax.py
rename to tests/test_infer_ops/triton/test_softmax.py
index 843d811d019c..6a244608c43f 100644
--- a/tests/test_kernels/test_softmax.py
+++ b/tests/test_infer_ops/triton/test_softmax.py
@@ -3,11 +3,19 @@
import torch
from torch import nn
-from colossalai.kernel.triton.ops import softmax
+
+try:
+ import triton
+ import triton.language as tl
+ from colossalai.kernel.triton.softmax import softmax
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_softmax_op():
data_samples = [
torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32),
diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py
new file mode 100644
index 000000000000..aee7944597dc
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_token_attn_1.py
@@ -0,0 +1,72 @@
+import math
+
+import pytest
+import torch
+from packaging import version
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+def torch_attn(xq, xk, bs, seqlen, num_head, head_dim):
+ xq = xq.view(bs, 1, num_head, head_dim)
+ xk = xk.view(bs, seqlen, num_head, head_dim)
+ keys = xk
+ xq = xq.transpose(1, 2)
+ keys = keys.transpose(1, 2)
+ scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(
+ num_head, -1)
+ return scores
+
+
+def torch_attn_1(xq, xk, seqlen, num_head, head_dim):
+ xq = xq.view(1, num_head, head_dim)
+ xk = xk.view(seqlen, num_head, head_dim)
+ logics = torch.sum(xq * xk, dim=-1, keepdim=False)
+
+ logics = logics.transpose(0, 1) / math.sqrt(head_dim)
+ return logics
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test_attn_1():
+ import time
+
+ batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
+
+ dtype = torch.float16
+
+ q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+ k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+ attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")
+
+ b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
+ kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+ kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+
+ for i in range(batch_size):
+ kv_cache_start_loc[i] = i * seq_len
+ kv_cache_seq_len[i] = seq_len
+ b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
+
+ token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
+
+ torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
+ o = attn_out.squeeze()
+ print("max ", torch.max(torch.abs(torch_out - o)))
+ print("mean ", torch.mean(torch.abs(torch_out - o)))
+ assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
+
+
+if __name__ == "__main__":
+ test_attn_1()
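`torch_attn` and `torch_attn_1` above compute the same decode-step logits two ways: a batched matmul of one query against all cached keys, and an elementwise product-and-sum. A tiny CPU check of that equivalence:

```python
import math

import torch

seqlen, num_head, head_dim = 5, 2, 8
xq = torch.randn(1, num_head, head_dim)
xk = torch.randn(seqlen, num_head, head_dim)

via_sum = torch.sum(xq * xk, dim=-1).transpose(0, 1) / math.sqrt(head_dim)        # [num_head, seqlen]
via_matmul = torch.einsum('qhd,khd->hqk', xq, xk).squeeze(1) / math.sqrt(head_dim)
assert torch.allclose(via_sum, via_matmul, atol=1e-6)
```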
diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py
new file mode 100644
index 000000000000..f834fedbb0f1
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_token_attn_2.py
@@ -0,0 +1,61 @@
+import math
+
+import pytest
+import torch
+from packaging import version
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+def torch_attn(V, P, bs, seqlen, num_head, head_dim):
+ V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
+ P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1)
+ attn_out = torch.matmul(P, V)
+
+ return attn_out
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test_token_attn_2():
+ import time
+
+ batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
+ dtype = torch.float16
+
+ V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
+ Prob = torch.empty(
+ (head_num, batch_size * seq_len), dtype=dtype,
+ device="cuda").normal_(mean=0.4, std=0.2).reshape(head_num, batch_size,
+ seq_len).softmax(-1).reshape(head_num, batch_size * seq_len)
+ attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda")
+
+ kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+ kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+ kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
+ for i in range(batch_size):
+ kv_cache_start_loc[i] = i * seq_len
+ kv_cache_seq_len[i] = seq_len
+ kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
+
+ token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
+
+ torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
+ o = attn_out
+ print("max ", torch.max(torch.abs(torch_out - o)))
+ print("mean ", torch.mean(torch.abs(torch_out - o)))
+ assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
+
+
+if __name__ == "__main__":
+ test_token_attn_2()
diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py
new file mode 100644
index 000000000000..e82318965e05
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py
@@ -0,0 +1,67 @@
+import time
+
+import pytest
+import torch
+from packaging import version
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
+ xq = xq.view(bs, 1, num_head, head_dim)
+ xk = xk.view(bs, seqlen, num_head, head_dim)
+ xv = xv.view(bs, seqlen, num_head, head_dim)
+
+ logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
+ prob = torch.softmax(logics, dim=1)
+ prob = prob.view(bs, seqlen, num_head, 1)
+
+ return torch.sum(prob * xv, dim=1, keepdim=False)
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test():
+
+ Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
+ dtype = torch.float16
+ q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+ k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
+ v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
+ o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
+ alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
+
+ max_kv_cache_len = seq_len
+ kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
+ kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
+ kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
+
+ kv_cache_seq_len[:] = seq_len
+ kv_cache_start_loc[0] = 0
+ kv_cache_start_loc[1] = seq_len
+ kv_cache_start_loc[2] = 2 * seq_len
+ kv_cache_start_loc[3] = 3 * seq_len
+
+ for i in range(Z):
+ kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
+
+ token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
+ torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
+
+ print("max ", torch.max(torch.abs(torch_out - o)))
+ print("mean ", torch.mean(torch.abs(torch_out - o)))
+ assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
+
+
+if __name__ == "__main__":
+ test()
diff --git a/tests/test_infer_ops/triton/test_token_softmax.py b/tests/test_infer_ops/triton/test_token_softmax.py
new file mode 100644
index 000000000000..08ffe1ca8323
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_token_softmax.py
@@ -0,0 +1,48 @@
+import pytest
+import torch
+from packaging import version
+
+try:
+ import triton
+ import triton.language as tl
+
+ from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+ reason="triton requires cuda version to be higher than 11.4")
+def test_softmax():
+
+ import torch
+
+ batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128
+
+ dtype = torch.float16
+
+ Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
+ ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
+
+ kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+ kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+
+ for i in range(batch_size):
+ kv_cache_start_loc[i] = i * seq_len
+ kv_cache_seq_len[i] = seq_len
+
+ token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)
+
+ torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
+ o = ProbOut
+ print("max ", torch.max(torch.abs(torch_out - o)))
+ print("mean ", torch.mean(torch.abs(torch_out - o)))
+ assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
+
+
+if __name__ == "__main__":
+ test_softmax()
diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
similarity index 93%
rename from tests/test_comm/test_boardcast_send_recv_v2.py
rename to tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
index 253f6f21cd80..c5fb049fe93f 100644
--- a/tests/test_comm/test_boardcast_send_recv_v2.py
+++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
@@ -1,10 +1,10 @@
import pytest
import torch
-from colossalai.communication.p2p_v2 import _recv_object, _send_object
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
+from colossalai.legacy.communication.p2p_v2 import _recv_object, _send_object
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py
similarity index 96%
rename from tests/test_comm/test_comm.py
rename to tests/test_legacy/test_comm/test_comm.py
index 747596bd2ded..3251d8d46f0b 100644
--- a/tests/test_comm/test_comm.py
+++ b/tests/test_legacy/test_comm/test_comm.py
@@ -2,10 +2,10 @@
import torch
import torch.distributed as dist
-from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
+from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py
similarity index 98%
rename from tests/test_comm/test_object_list_p2p.py
rename to tests/test_legacy/test_comm/test_object_list_p2p.py
index e9d7630c1543..f50982ee1c2d 100644
--- a/tests/test_comm/test_object_list_p2p.py
+++ b/tests/test_legacy/test_comm/test_object_list_p2p.py
@@ -1,7 +1,10 @@
import pytest
import torch
-from colossalai.communication.p2p import (
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.legacy.communication.p2p import (
recv_backward,
recv_forward,
send_backward,
@@ -9,9 +12,6 @@
send_forward,
send_forward_recv_backward,
)
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(pipeline=2))
diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
similarity index 97%
rename from tests/test_comm/test_object_list_p2p_v2.py
rename to tests/test_legacy/test_comm/test_object_list_p2p_v2.py
index cae38385b6e1..040c63322f2b 100644
--- a/tests/test_comm/test_object_list_p2p_v2.py
+++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
@@ -1,10 +1,10 @@
import pytest
import torch
-from colossalai.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
+from colossalai.legacy.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py
similarity index 100%
rename from tests/test_engine/test_engine.py
rename to tests/test_legacy/test_engine/test_engine.py
diff --git a/tests/test_engine/test_gradient_accumluation.py b/tests/test_legacy/test_engine/test_gradient_accumluation.py
similarity index 100%
rename from tests/test_engine/test_gradient_accumluation.py
rename to tests/test_legacy/test_engine/test_gradient_accumluation.py
diff --git a/tests/test_layers/test_2p5d/checks_2p5d/__init__.py b/tests/test_legacy/test_layers/test_1d/checks_1d/__init__.py
similarity index 100%
rename from tests/test_layers/test_2p5d/checks_2p5d/__init__.py
rename to tests/test_legacy/test_layers/test_1d/checks_1d/__init__.py
diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
similarity index 99%
rename from tests/test_layers/test_1d/checks_1d/check_layer_1d.py
rename to tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
index 668b8a334800..dcb2be62671b 100644
--- a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py
+++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
@@ -5,7 +5,7 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.nn import (
+from colossalai.legacy.nn import (
Classifier1D,
Embedding1D,
Linear1D_Col,
diff --git a/tests/test_layers/test_1d/checks_1d/common.py b/tests/test_legacy/test_layers/test_1d/checks_1d/common.py
similarity index 94%
rename from tests/test_layers/test_1d/checks_1d/common.py
rename to tests/test_legacy/test_layers/test_1d/checks_1d/common.py
index 8b7b28613d22..29a9a3d20330 100644
--- a/tests/test_layers/test_1d/checks_1d/common.py
+++ b/tests/test_legacy/test_layers/test_1d/checks_1d/common.py
@@ -1,15 +1,16 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch
-
-DEPTH = 4
-BATCH_SIZE = 8
-SEQ_LENGTH = 8
-IMG_SIZE = 16
-HIDDEN_SIZE = 8
-NUM_CLASSES = 8
-VOCAB_SIZE = 16
-
-def check_equal(A, B):
- assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+
+DEPTH = 4
+BATCH_SIZE = 8
+SEQ_LENGTH = 8
+IMG_SIZE = 16
+HIDDEN_SIZE = 8
+NUM_CLASSES = 8
+VOCAB_SIZE = 16
+
+
+def check_equal(A, B):
+ assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True
diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py
similarity index 100%
rename from tests/test_layers/test_1d/test_1d.py
rename to tests/test_legacy/test_layers/test_1d/test_1d.py
diff --git a/tests/test_layers/test_3d/checks_3d/__init__.py b/tests/test_legacy/test_layers/test_2d/checks_2d/__init__.py
similarity index 100%
rename from tests/test_layers/test_3d/checks_3d/__init__.py
rename to tests/test_legacy/test_layers/test_2d/checks_2d/__init__.py
diff --git a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
similarity index 97%
rename from tests/test_layers/test_2d/checks_2d/check_layer_2d.py
rename to tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
index e030e473a363..0ee88c26035f 100644
--- a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
@@ -1,12 +1,23 @@
import torch
+
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D,
- VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D,
- VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D)
+from colossalai.legacy.nn import (
+ Classifier2D,
+ CrossEntropyLoss2D,
+ Embedding2D,
+ LayerNorm2D,
+ Linear2D,
+ PatchEmbedding2D,
+ VanillaClassifier,
+ VanillaPatchEmbedding,
+ VocabParallelClassifier2D,
+ VocabParallelCrossEntropyLoss2D,
+ VocabParallelEmbedding2D,
+)
from colossalai.utils import get_current_device, print_rank_0
-from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal)
+from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear():
@@ -336,7 +347,7 @@ def check_classifier_no_given_weight():
layer.weight.data.copy_(W)
# W.requires_grad = True
- B_shape = (OUTPUT_SIZE, )
+ B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, DEPTH, dim=0)[j]
@@ -572,7 +583,7 @@ def check_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
- target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+ target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -607,7 +618,7 @@ def check_vocab_parallel_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
- target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+ target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
diff --git a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
similarity index 96%
rename from tests/test_layers/test_2d/checks_2d/check_operation_2d.py
rename to tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
index a5e37b1ec309..ae1d1120cfb9 100644
--- a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
@@ -5,10 +5,10 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
-from colossalai.utils import get_current_device
-from colossalai.utils import print_rank_0
-from .common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH
+from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
+from colossalai.utils import get_current_device, print_rank_0
+
+from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal
def check_AB():
diff --git a/tests/test_layers/test_2d/checks_2d/common.py b/tests/test_legacy/test_layers/test_2d/checks_2d/common.py
similarity index 100%
rename from tests/test_layers/test_2d/checks_2d/common.py
rename to tests/test_legacy/test_layers/test_2d/checks_2d/common.py
diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py
similarity index 100%
rename from tests/test_layers/test_2d/test_2d.py
rename to tests/test_legacy/test_layers/test_2d/test_2d.py
diff --git a/tests/test_layers/test_sequence/checks_seq/__init__.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/__init__.py
similarity index 100%
rename from tests/test_layers/test_sequence/checks_seq/__init__.py
rename to tests/test_legacy/test_layers/test_2p5d/checks_2p5d/__init__.py
diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
similarity index 98%
rename from tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
rename to tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
index a8f551093b1e..5a99b05cfe7e 100644
--- a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
@@ -1,11 +1,22 @@
import torch
+from torch.nn import Parameter
+
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D,
- PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D,
- VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D)
+from colossalai.legacy.nn import (
+ Classifier2p5D,
+ CrossEntropyLoss2p5D,
+ Embedding2p5D,
+ LayerNorm2p5D,
+ Linear2p5D,
+ PatchEmbedding2p5D,
+ VanillaClassifier,
+ VanillaPatchEmbedding,
+ VocabParallelClassifier2p5D,
+ VocabParallelCrossEntropyLoss2p5D,
+ VocabParallelEmbedding2p5D,
+)
from colossalai.utils import get_current_device, print_rank_0
-from torch.nn import Parameter
from .common import *
@@ -342,7 +353,7 @@ def check_classifier_no_given_weight():
layer.weight.data.copy_(W)
# W.requires_grad = True
- B_shape = (OUTPUT_SIZE, )
+ B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
@@ -577,7 +588,7 @@ def check_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
- target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+ target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
@@ -612,7 +623,7 @@ def check_vocab_parallel_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
- target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+ target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
similarity index 97%
rename from tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
rename to tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
index d0c3b02fccba..db19967676d2 100644
--- a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
@@ -2,10 +2,9 @@
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, \
- Matmul_ATB_2p5D
-from colossalai.utils import get_current_device
-from colossalai.utils import print_rank_0
+from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D
+from colossalai.utils import get_current_device, print_rank_0
+
from .common import *
diff --git a/tests/test_layers/test_2p5d/checks_2p5d/common.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/common.py
similarity index 75%
rename from tests/test_layers/test_2p5d/checks_2p5d/common.py
rename to tests/test_legacy/test_layers/test_2p5d/checks_2p5d/common.py
index aff85f109666..c90d8fc086bd 100644
--- a/tests/test_layers/test_2p5d/checks_2p5d/common.py
+++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/common.py
@@ -11,4 +11,4 @@
def check_equal(A, B):
- assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
\ No newline at end of file
+ assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
similarity index 100%
rename from tests/test_layers/test_2p5d/test_2p5d.py
rename to tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/__init__.py b/tests/test_legacy/test_layers/test_3d/checks_3d/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
similarity index 99%
rename from tests/test_layers/test_3d/checks_3d/check_layer_3d.py
rename to tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
index e946a1f5912d..cee639a9f00a 100644
--- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
+++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
@@ -7,8 +7,7 @@
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context
-from colossalai.logging import get_dist_logger
-from colossalai.nn import (
+from colossalai.legacy.nn import (
Classifier3D,
CrossEntropyLoss3D,
Embedding3D,
@@ -21,7 +20,8 @@
VocabParallelCrossEntropyLoss3D,
VocabParallelEmbedding3D,
)
-from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, print_rank_0
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_legacy/test_layers/test_3d/checks_3d/common.py
similarity index 95%
rename from tests/test_layers/test_3d/checks_3d/common.py
rename to tests/test_legacy/test_layers/test_3d/checks_3d/common.py
index afb19c4745cc..509fc2cecf59 100644
--- a/tests/test_layers/test_3d/checks_3d/common.py
+++ b/tests/test_legacy/test_layers/test_3d/checks_3d/common.py
@@ -16,4 +16,4 @@
def check_equal(A, B):
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
assert eq, f"\nA = {A}\nB = {B}"
- return eq
\ No newline at end of file
+ return eq
diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py
similarity index 100%
rename from tests/test_layers/test_3d/test_3d.py
rename to tests/test_legacy/test_layers/test_3d/test_3d.py
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py
similarity index 99%
rename from tests/test_layers/test_cache_embedding.py
rename to tests/test_legacy/test_layers/test_cache_embedding.py
index 22d4f02a48d7..0760a3f1ec38 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_legacy/test_layers/test_cache_embedding.py
@@ -6,7 +6,7 @@
import torch
import colossalai
-from colossalai.nn.parallel.layers import (
+from colossalai.legacy.nn.parallel.layers import (
CachedEmbeddingBag,
CachedParamMgr,
EvictionStrategy,
diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/__init__.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
similarity index 91%
rename from tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
rename to tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
index 2b7b999d4373..7ff91a7b76e0 100644
--- a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
+++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
@@ -2,7 +2,7 @@
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn import TransformerSelfAttentionRing
+from colossalai.legacy.nn import TransformerSelfAttentionRing
from colossalai.utils import get_current_device
diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_legacy/test_layers/test_sequence/test_sequence.py
similarity index 97%
rename from tests/test_layers/test_sequence/test_sequence.py
rename to tests/test_legacy/test_layers/test_sequence/test_sequence.py
index 60f2d55f43af..b9e6c12479ee 100644
--- a/tests/test_layers/test_sequence/test_sequence.py
+++ b/tests/test_legacy/test_layers/test_sequence/test_sequence.py
@@ -5,6 +5,7 @@
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
@@ -42,7 +43,7 @@ def check_ring_qk(rank, world_size):
a = torch.matmul(q, k.transpose(2, 1))
# compute distributed attention scores
- ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply
+ ring_qk = RingQK.apply
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
# check master and distributed attention scores
@@ -95,7 +96,7 @@ def check_ring_av(rank, world_size):
out = torch.matmul(a, v)
# compute distributed attention scores
- ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply
+ ring_av = RingAV.apply
sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
similarity index 98%
rename from tests/test_trainer/test_pipeline/test_p2p.py
rename to tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
index 8ad366133d18..5fb678525bb3 100644
--- a/tests/test_trainer/test_pipeline/test_p2p.py
+++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
@@ -5,7 +5,10 @@
import torch
import torch.distributed as dist
-from colossalai.communication import (
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.legacy.communication import (
recv_backward,
recv_forward,
recv_obj_meta,
@@ -15,9 +18,6 @@
send_forward_recv_backward,
send_obj_meta,
)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
similarity index 100%
rename from tests/test_trainer/test_pipeline/test_pipeline_schedule.py
rename to tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
similarity index 97%
rename from tests/test_trainer/test_trainer_with_non_pipe_schedule.py
rename to tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
index 753f82222f9d..dab0e53a4c32 100644
--- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
+++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
@@ -3,9 +3,9 @@
import colossalai
from colossalai.amp.amp_type import AMP_TYPE
+from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer
from tests.components_to_test.registry import non_distributed_component_funcs
diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
similarity index 98%
rename from tests/test_trainer/test_trainer_with_pipe_schedule.py
rename to tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
index bb63d51a0b65..7dfbec854ccc 100644
--- a/tests/test_trainer/test_trainer_with_pipe_schedule.py
+++ b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
@@ -12,9 +12,9 @@
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, get_dataloader
BATCH_SIZE = 4
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index e7002a75f3f7..9c84a99cd549 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -5,7 +5,7 @@
import colossalai
from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index ec37967f18c5..595d4374df6f 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -3,7 +3,7 @@
import colossalai
from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.init_ctx import ZeroInitContext
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index efc6e9ddae27..a43ae764dccd 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -4,7 +4,7 @@
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
diff --git a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py
deleted file mode 100644
index 6a0509555862..000000000000
--- a/tests/test_pipeline/test_cuda_rpc_performance.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import os
-from typing import Callable, List, Optional, Type, Union
-import time
-
-import pytest
-import torch
-import torch.nn as nn
-from titans.dataloader.cifar10 import build_cifar
-from torchvision.models import resnet50
-from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
-from tqdm import tqdm
-
-from rpc_test_utils import rpc_run, parse_args
-import colossalai
-import colossalai.nn as col_nn
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, get_dataloader
-from colossalai.context import ParallelMode
-from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
-from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine
-from colossalai.pipeline.pipeline_process_group import ppg
-
-
-def flatten(x):
- return torch.flatten(x, 1)
-
-
-def partition(pp_rank: int, chunk: int, stage_num: int):
- pipelinable = PipelinableContext()
-
- # build model partitions
- with pipelinable:
- # input : [B, 3, 32, 32]
- _ = resnet50()
-
- pipelinable.policy = "customized"
-
- exec_seq = [
- 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
- ]
- pipelinable.to_layer_list(exec_seq)
- partition = pipelinable.partition(chunk, stage_num, pp_rank)
- return partition
-
-
-def run_master(args):
- batch_size = args.batch_size
- chunk = args.chunk
- device = args.device
- world_size = args.world_size
- stage_num = world_size
- num_microbatches = args.num_microbatches
-
- # build dataloader
- root = os.environ.get('DATA', './data')
- train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
- criterion = nn.CrossEntropyLoss()
-
- pp_engine = OneFOneBPipelineEngine(partition_fn=partition,
- stage_num=stage_num,
- num_microbatches=num_microbatches,
- device=device,
- chunk=chunk,
- criterion=criterion,
- checkpoint=False)
-
- pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
- s = time.time()
-
- for bx, by in tqdm(train_dataloader):
- pp_engine.forward_backward(bx, labels=by, forward_only=False)
-
- cost_time = time.time() - s
-
- print("total cost time :", cost_time)
- print("cost time per batch:", cost_time / len(train_dataloader))
-
-
-@pytest.mark.skip("Test for performance, no need for CI")
-def main():
- args = parse_args()
- # this is due to limitation of partition function
- args.world_size = 2
- args.chunk = 1
- rpc_run(args, run_master)
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
new file mode 100644
index 000000000000..0cbb852b97a0
--- /dev/null
+++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
@@ -0,0 +1,39 @@
+from colossalai.shardformer.policies.t5 import T5BasePolicy
+
+
+def test_t5_pipeline_distribution():
+ num_test_cases = 8
+ test_dict = {
+ 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
+ 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
+ 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
+ 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
+ }
+
+ for i in range(num_test_cases):
+ _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
+ test_dict['num_decoder_layers'][i],
+ test_dict['num_stages'][i])
+ assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
+
+
+def test_t5_pipeline_layers():
+ num_test_cases = 4
+ test_dict = {
+ 'num_encoder_layers': [2, 3, 2, 4],
+ 'num_decoder_layers': [2, 0, 2, 8],
+ 'num_stages': [2, 2, 4, 4],
+ 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
+ [[0, 4], [0, 3], [3, 6], [6, 8]]]
+ }
+
+ for i in range(num_test_cases):
+ layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+ test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
+
+ for stage in range(test_dict['num_stages'][i]):
+ start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
+ predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
+ decoder_starting_stage)
+ assert start_idx == predicted_start
+ assert end_idx == predicted_end
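
The expected `decoder_starting_stage` values above encode how pipeline stages are divided between the T5 encoder and decoder. Below is a minimal sketch of one splitting rule that reproduces these test cases: choose the encoder/decoder split that best balances the average number of layers per stage on each side. `split_stages` is a hypothetical illustration consistent with the tests, not the actual `T5BasePolicy.distribute_t5_layers` implementation.

```python
def split_stages(num_encoder_layers: int, num_decoder_layers: int, num_stages: int) -> int:
    """Return the stage index at which the decoder starts (hypothetical sketch)."""
    if num_decoder_layers == 0:
        # encoder-only model: every stage belongs to the encoder
        return num_stages
    best_split, best_diff = 1, float('inf')
    for encoder_stages in range(1, num_stages):
        decoder_stages = num_stages - encoder_stages
        # balance the average layers per stage on the two sides
        diff = abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)
        if diff < best_diff:
            best_split, best_diff = encoder_stages, diff
    return best_split

assert split_stages(10, 6, 8) == 5    # matches a test case above
assert split_stages(3, 0, 2) == 2     # encoder-only case above
```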
diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
new file mode 100644
index 000000000000..395519e97898
--- /dev/null
+++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
@@ -0,0 +1,44 @@
+from colossalai.shardformer.policies.whisper import WhisperPolicy
+
+
+def test_whisper_pipeline_distribution():
+ num_test_cases = 8
+ test_dict = {
+ 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
+ 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
+ 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
+ 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
+ }
+
+ for i in range(num_test_cases):
+ _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
+ test_dict['num_decoder_layers'][i],
+ test_dict['num_stages'][i])
+ assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
+
+
+def test_whisper_pipeline_layers():
+ num_test_cases = 4
+ test_dict = {
+ 'num_encoder_layers': [2, 3, 2, 4],
+ 'num_decoder_layers': [2, 0, 2, 8],
+ 'num_stages': [2, 2, 4, 4],
+ 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
+ [[0, 4], [0, 3], [3, 6], [6, 8]]]
+ }
+
+ for i in range(num_test_cases):
+ layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
+
+ for stage in range(test_dict['num_stages'][i]):
+ start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
+ predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
+ decoder_starting_stage)
+ assert start_idx == predicted_start
+ assert end_idx == predicted_end
+
+
+if __name__ == '__main__':
+ test_whisper_pipeline_distribution()
+ test_whisper_pipeline_layers()
diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py
new file mode 100644
index 000000000000..a995d17e5da6
--- /dev/null
+++ b/tests/test_pipeline/test_schedule/test_interleaved.py
@@ -0,0 +1,161 @@
+import copy
+from functools import partial
+from types import MethodType
+
+import pytest
+import torch
+import torch.nn as nn
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import OptimizerWrapper
+from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+
+
+class MlpModel(nn.Module):
+
+ def __init__(self):
+ super(MlpModel, self).__init__()
+ self.linear1 = nn.Linear(4, 8)
+ self.linear2 = nn.Linear(8, 8)
+ self.linear3 = nn.Linear(8, 8)
+ self.linear4 = nn.Linear(8, 8)
+ self.linear5 = nn.Linear(8, 8)
+ self.linear6 = nn.Linear(8, 8)
+ self.linear7 = nn.Linear(8, 8)
+ self.linear8 = nn.Linear(8, 4)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.linear2(x)
+ x = self.linear3(x)
+ x = self.linear4(x)
+ x = self.linear5(x)
+ x = self.linear6(x)
+ x = self.linear7(x)
+ x = self.linear8(x)
+ return x
+
+
+def pp_linear_fwd(forward,
+ data: torch.Tensor = None,
+ input_obj: torch.Tensor = None,
+ stage_mgr: PipelineStageManager = None,
+ num_chunks: int = None,
+ model_chunk_id: int = None):
+
+ if stage_mgr.is_first_stage() and model_chunk_id == 0:
+ return {'input_obj': forward(data)}
+ elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
+ return forward(input_obj)
+ else:
+ return {'input_obj': forward(input_obj)}
+
+
+@parameterize("num_micro_batches", [4, 8, 12])
+def examine_pp(num_micro_batches):
+ """
+    This test examines the correctness of the interleaved 1F1B schedule by comparing against plain PyTorch execution.
+    Be aware that it contains some hardcoded values.
+ """
+ world_size = torch.distributed.get_world_size()
+ local_rank = torch.distributed.get_rank()
+ seed_all(1453)
+
+    NUM_MICRO_BATCHES = num_micro_batches
+ BATCH_SIZE = num_micro_batches
+ NUM_CHUNKS = 2
+
+ # create model
+ torch_model = MlpModel().cuda()
+
+ pp_model = copy.deepcopy(torch_model).cuda()
+
+ DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
+ pg_mesh = ProcessGroupMesh(1, world_size, 1)
+ stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
+    schedule = InterleavedSchedule(NUM_MICRO_BATCHES, NUM_CHUNKS, stage_manager)
+
+ sharded_model = torch.nn.ModuleList()
+ for idx, (_, sub_model) in enumerate(pp_model.named_children()):
+ if idx % (world_size) == local_rank:
+ sub_model._forward = sub_model.forward
+ sub_model.forward = MethodType(
+ partial(pp_linear_fwd,
+ stage_mgr=stage_manager,
+ num_chunks=NUM_CHUNKS,
+ model_chunk_id=len(sharded_model)), sub_model._forward)
+ sharded_model.append(sub_model.cuda())
+
+ # create optimizer
+ torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+ pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
+
+    # create input data, identical across ranks after the all-reduce below
+ seed_all(1453)
+ if local_rank == 0:
+ input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
+ else:
+ input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
+ torch.distributed.all_reduce(input_list[0])
+
+ criterion = lambda x, y: torch.mean(x)
+
+ # forward and backward
+ torch_output = torch_model(input_list[0])
+    torch_loss = criterion(torch_output, None)    # the criterion above ignores its second argument
+ torch_loss.backward()
+
+ pp_ret = schedule.forward_backward_step(sharded_model,
+ iter(input_list),
+ criterion,
+ pp_optimizer,
+ return_loss=True,
+ return_outputs=True)
+
+ # check loss
+ if stage_manager.is_last_stage():
+ assert torch.allclose(torch_loss, pp_ret['loss'])
+
+ # check gradients
+ torch_grad = []
+ for torch_p in torch_model.parameters():
+ torch_grad.append(torch_p.grad.data)
+
+ for idx, pp_p in enumerate(sharded_model.parameters()):
+ if idx < 2:
+ assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
+ else:
+ assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
+
+ # step
+ torch_optimizer.step()
+ pp_optimizer.step()
+
+ # check updated param
+ torch_param = []
+ for torch_p in torch_model.parameters():
+ torch_param.append(torch_p.data)
+ for idx, pp_p in enumerate(sharded_model.parameters()):
+ if idx < 2:
+ assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
+ else:
+ assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+ examine_pp()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_pp():
+ spawn(run_dist, 4)
+
+
+if __name__ == '__main__':
+ test_pp()
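
In the test above, `idx % world_size == local_rank` assigns model chunks round-robin: with 8 linear layers, 4 pipeline ranks and `NUM_CHUNKS = 2`, rank `r` holds layers `r` and `r + 4`. This also explains the `+ 6` offset in the gradient and parameter assertions: each `nn.Linear` has two parameters (weight and bias), so for the first chunk (layer `r`) local parameter `idx` maps to global parameter `idx + 2 * r`, while for the second chunk (layer `r + 4`) it maps to `idx + 2 * r + 6`. A standalone sketch of the assignment:

```python
# Illustrative snippet mirroring `idx % world_size == local_rank` above.
num_layers, world_size = 8, 4
for rank in range(world_size):
    held = [idx for idx in range(num_layers) if idx % world_size == rank]
    print(f"rank {rank} holds layers {held}")
# rank 0 holds layers [0, 4]
# rank 1 holds layers [1, 5]
# rank 2 holds layers [2, 6]
# rank 3 holds layers [3, 7]
```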
diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
index 542116a1da75..41b535573c39 100644
--- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py
+++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
@@ -61,7 +61,7 @@ def examine_pp():
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
- schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
+ schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank:
@@ -90,9 +90,9 @@ def examine_pp():
torch_loss.backward()
pp_ret = schedule.forward_backward_step(sharded_model,
- pp_optimizer,
iter(input_list),
criterion,
+ pp_optimizer,
return_loss=True,
return_outputs=True)
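
The two hunks above reflect an API change: `OneForwardOneBackwardSchedule` now takes the stage manager as its first argument with `num_microbatches` passed by keyword, and `forward_backward_step` now takes the data iterator and criterion before the optimizer. A usage sketch under these assumptions (names taken from the diff):

```python
# Sketch of the reordered API, assuming the signatures implied by the diff above.
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
pp_ret = schedule.forward_backward_step(
    sharded_model,       # model (or list of model chunks) to run
    iter(input_list),    # data iterator now precedes the optimizer
    criterion,
    pp_optimizer,
    return_loss=True,
    return_outputs=True,
)
```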
diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py
index be4591d58f74..6e0cd1998c11 100644
--- a/tests/test_pipeline/test_stage_manager.py
+++ b/tests/test_pipeline/test_stage_manager.py
@@ -49,15 +49,6 @@ def check_stage_manager():
next_rank = ranks_in_group[ranks_in_group.index(rank) + 1]
assert stage_manager.get_next_rank() == next_rank
- # check virtual stage
- stage_manager.set_num_virtual_stages(PP_SIZE * 2)
- assert stage_manager.num_virtual_stages == PP_SIZE * 2
- stage_manager.set_virtual_stage(stage_manager.stage * 2)
- assert stage_manager.virtual_stage == stage_manager.stage * 2
- with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1):
- assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1
- assert stage_manager.virtual_stage == stage_manager.stage * 2
-
# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index b45cd172c3ca..4c0f884a7ed5 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor
-@parameterize('lazy_init', [False, True])
-def check_linear_conv_1d_col(lazy_init: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
@@ -62,7 +61,9 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
- n_fused=3)
+ seq_parallel=seq_parallel,
+ n_fused=3,
+ overlap=overlap)
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
@@ -76,10 +77,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear.load_state_dict(linear_conv_col.state_dict())
# check computation correctness
- x = torch.rand(4, 48).cuda()
+ x = torch.rand(1, 4, 48).cuda()
out = linear(x)
- gather_out = linear_conv_col(x)
- assert_close(rearrange(out, 1), gather_out)
+ x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+ gather_out = linear_conv_col(x_for_shard)
+ assert_close(rearrange(out, -1), gather_out)
# check backward correctness
out.sum().backward()
@@ -89,14 +91,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
assert_close(target_grad, linear_conv_col.weight.grad)
-@parameterize('lazy_init', [False, True])
-def check_linear_conv_1d_row(lazy_init: bool):
+def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
- linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+ linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
+ process_group=None,
+ parallel_input=False,
+ seq_parallel=seq_parallel)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
@@ -109,10 +113,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
- x = torch.rand(4, 48).cuda()
+ x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_row(x)
- assert_close(out, gather_out)
+ target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+ assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
@@ -123,12 +128,19 @@ def check_linear_conv_1d_row(lazy_init: bool):
assert_close(target_grad, linear_row.weight.grad)
+@parameterize('lazy_init', [False, True])
+@parameterize('seq_parallel', [False, True])
+@parameterize('overlap', [True])
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
+ check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
+ check_linear_conv_1d_row(lazy_init, seq_parallel)
+
+
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# test for linear conv
- check_linear_conv_1d_col()
- check_linear_conv_1d_row()
+ check_gpt2_qkv_fused_linear_1d()
@rerun_if_address_is_in_use()
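
The sequence-parallel test changes above follow a common pattern: inputs gain an explicit sequence dimension (`[batch_size, seq_len, hidden_size]`), each rank feeds one chunk of the sequence into the sharded layer, and outputs and input gradients are compared against the matching chunk of the unsharded reference. A minimal sketch of the chunking logic these tests inline (`shard_seq` is a hypothetical helper; the tests hardcode 2 ranks):

```python
import torch
import torch.distributed as dist


def shard_seq(x: torch.Tensor, seq_parallel: bool, num_ranks: int = 2) -> torch.Tensor:
    """Return this rank's slice of x along the sequence dim (dim=1),
    or x unchanged when sequence parallelism is off."""
    if not seq_parallel:
        return x
    return torch.chunk(x.clone(), num_ranks, dim=1)[dist.get_rank()]

# With seq_parallel=True and 2 ranks, rank 0 sees x[:, :seq_len // 2, :]
# and rank 1 sees x[:, seq_len // 2:, :].
```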
diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py
index aa75879e0313..e6d86d533ed6 100644
--- a/tests/test_shardformer/test_layer/test_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_linear_1d.py
@@ -12,13 +12,16 @@
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-@parameterize('lazy_init', [False, True])
-def check_linear_1d_col(lazy_init: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
- linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
+ linear_col = Linear1D_Col.from_native_module(linear_copy,
+ process_group=None,
+ gather_output=True,
+ seq_parallel=seq_parallel,
+ overlap=overlap)
# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
@@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
linear_col.load_state_dict(linear.state_dict())
# check computation correctness
- x = torch.rand(4, 32).cuda()
+ # [batch_size, seq_len, hidden_size]
+ x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
- x_for_shard = x.expand_as(x.clone())
+ x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
out = linear(x_for_unshard)
@@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
- assert_close(x_for_unshard.grad, x_for_shard.grad)
+    target_unshard_grad = x_for_unshard.grad if seq_parallel is False else torch.chunk(
+        x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_unshard_grad, x_for_shard.grad)
-@parameterize('lazy_init', [False, True])
-def check_linear_1d_row(lazy_init: bool):
+def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
- linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+ linear_row = Linear1D_Row.from_native_module(linear_copy,
+ process_group=None,
+ parallel_input=False,
+ seq_parallel=seq_parallel)
assert linear_row.weight.shape == torch.Size([128, 16])
assert linear_row.bias.shape == torch.Size([128])
@@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
linear_row.load_state_dict(linear.state_dict())
# check computation correctness
- x = torch.rand(4, 32).cuda()
+ # [batch_size, seq_len, hidden_size]
+ x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
@@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
- assert_close(out, gather_out)
+ target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+ assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
@@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)
-@parameterize('lazy_init', [False, True])
-def check_linear_col_plus_row(lazy_init: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear_1 = nn.Linear(32, 128).cuda()
@@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
with ctx:
linear_1_copy = nn.Linear(32, 128).cuda()
linear_2_copy = nn.Linear(128, 32).cuda()
- linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
- linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
+ linear_col = Linear1D_Col.from_native_module(linear_1_copy,
+ process_group=None,
+ gather_output=False,
+ seq_parallel=seq_parallel,
+ overlap=overlap)
+ linear_row = Linear1D_Row.from_native_module(linear_2_copy,
+ process_group=None,
+ parallel_input=True,
+ seq_parallel=seq_parallel)
linear_1.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear_1.state_dict())
@@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
linear_row.load_state_dict(linear_2.state_dict())
# check computation correctness
- x = torch.rand(4, 32).cuda()
+ # [batch_size, seq_len, hidden_size]
+ x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
- x_for_shard = x.expand_as(x.clone())
+ x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
- assert_close(unshard_out, shard_out)
+ target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+ assert_close(target_out, shard_out)
# check backward correctness
unshard_out.sum().backward()
@@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
- assert_close(x_for_unshard.grad, x_for_shard.grad)
+    target_unshard_grad = x_for_unshard.grad if seq_parallel is False else torch.chunk(
+        x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_unshard_grad, x_for_shard.grad)
+
+
+@parameterize('lazy_init', [False, True])
+@parameterize('seq_parallel', [False, True])
+@parameterize('overlap', [True])
+def run_dist_linear_test(lazy_init, seq_parallel, overlap):
+ check_linear_1d_col(lazy_init, seq_parallel, overlap)
+ check_linear_1d_row(lazy_init, seq_parallel)
+ check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
-def run_dist(rank, world_size, port):
+def check_dist_linear(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- check_linear_1d_col()
- check_linear_1d_row()
- check_linear_col_plus_row()
+ run_dist_linear_test()
@rerun_if_address_is_in_use()
def test_linear():
- spawn(run_dist, nprocs=2)
+ spawn(check_dist_linear, nprocs=2)
if __name__ == '__main__':
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 921af2a8b1d0..c9c6447a43f0 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,4 +1,5 @@
import copy
+import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
@@ -12,6 +13,7 @@
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -25,6 +27,7 @@ def build_model(model_fn,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
+ enable_sequence_parallelism=False,
use_lazy_init: bool = False):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -38,7 +41,8 @@ def build_model(model_fn,
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
- enable_jit_fused=enable_jit_fused)
+ enable_jit_fused=enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
@@ -135,6 +139,16 @@ def _criterion(outputs, inputs):
return loss
data = data_gen_fn()
+
+ if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
+ seq_len = data['input_ids'].shape[-1]
+ lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
+ times = lcm // seq_len
+ input_shape = data['input_ids'].shape
+ for k, v in data.items():
+ if v.shape == input_shape:
+                # repeat only along the sequence (last) dim so seq_len becomes divisible by tp_size
+                data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))
+
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
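
The repetition above pads the sequence length to a multiple of the tensor-parallel size. A worked example of the arithmetic:

```python
import math

tp_size, seq_len = 2, 3
lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)    # 6
times = lcm // seq_len                                   # 2
# repeating the sequence dim twice gives length 6, which is divisible by tp_size
```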
@@ -177,11 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
org_hidden_state = org_output.last_hidden_state
- if stage_manager is None:
- sharded_hidden_state = sharded_output.last_hidden_state
-
if stage_manager and stage_manager.is_last_stage():
- sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
+ sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
+ else:
+ sharded_hidden_state = sharded_output.last_hidden_state
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
@@ -219,6 +232,43 @@ def check_weight(org_model: Module,
f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+def get_grad_tensors_for_check(org_model: Module,
+ sharded_model: Module,
+ layer_suffix: List[str],
+ tp_group: ProcessGroup = None,
+ dim: int = 0,
+ atol: float = 1e-5,
+ rtol: float = 1e-3,
+ verbose: bool = False,
+ name: str = None):
+
+ grad_to_check = {}
+ for suffix in layer_suffix:
+ org_grad = getattr_(org_model, suffix).weight.grad
+ shard_grad = getattr_(sharded_model, suffix).weight.grad
+ shard_weight = getattr_(sharded_model, suffix).weight
+ if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+ shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
+ dist.all_gather(shard_grad_list, shard_grad, tp_group)
+ shard_grad = torch.cat(shard_grad_list, dim=dim)
+
+ # embedding may be resized when using tensor parallel
+ if shard_grad.shape[0] > org_grad.shape[0]:
+ shard_grad = shard_grad[:org_grad.shape[0], :]
+ if verbose and dist.get_rank() == 0:
+ print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
+ grad_to_check[suffix] = {
+ "org_grad": org_grad.float(),
+ "shard_grad": shard_grad.float(),
+ "rtol": rtol,
+ "atol": atol
+ }
+
+ return grad_to_check
+
+
+# used by sam/blip2
def check_grad(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
@@ -231,7 +281,6 @@ def check_grad(org_model: Module,
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
-
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
@@ -246,3 +295,30 @@ def check_grad(org_model: Module,
assert torch.allclose(
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
+
+
+def unwrap_model(module: Module,
+ base_model_class_name: Optional[str] = None,
+ base_model_attribute_name: Optional[str] = None):
+ if isinstance(module, HybridParallelModule):
+ module = module.unwrap()
+ if base_model_class_name is None:
+ return module
+ if module.__class__.__name__ == base_model_class_name:
+ return module
+ return getattr(module, base_model_attribute_name, None)
+
+
+def check_all_grad_tensors(check_tensors):
+ """
+ "org_grad": tensor to be compared from the original model
+ "shard_grad": tensor to be compared from the sharded model
+ """
+ for suffix, check_info in check_tensors.items():
+ org_grad = check_info["org_grad"]
+ shard_grad = check_info["shard_grad"]
+ rtol = check_info["rtol"]
+ atol = check_info["atol"]
+ assert torch.allclose(
+ org_grad, shard_grad, atol=atol, rtol=rtol
+ ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 0a24e46d28f2..c779e417052b 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -10,11 +10,13 @@
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
@@ -32,42 +34,58 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
output_transform_fn,
criterion,
booster)
+
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
- # check last hidden state & loss
- if stage_manager is None or stage_manager.is_last_stage():
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-5, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
- if org_model.__class__.__name__ == 'BertModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- # unwrap model
- if org_model.__class__.__name__ == 'BertModel':
- bert = org_model
- sharded_bert = sharded_model.unwrap()
- else:
- bert = org_model.bert
- sharded_bert = sharded_model.unwrap().bert
+ bert = unwrap_model(org_model, 'BertModel', 'bert')
+ sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
- if stage_manager is None or stage_manager.is_first_stage():
- #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
- #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
- check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
- check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-
- # check weights after optimizer.step()
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ col_layer_grads = get_grad_tensors_for_check(bert,
+ sharded_bert,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ row_layer_grads = get_grad_tensors_for_check(bert,
+ sharded_bert,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+    # run the optimizer step
org_optimizer.step()
sharded_optimizer.step()
+
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ if org_model.__class__.__name__ == 'BertModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+ # check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
else:
@@ -75,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
torch.cuda.empty_cache()
@@ -98,6 +119,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 1,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
def run_bert_test(test_config):
@@ -111,12 +155,50 @@ def run_bert_test(test_config):
torch.cuda.empty_cache()
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'zero_stage': 1,
+ 'initial_scale': 1,
+ },
+])
+def run_bert_3d_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_test()
+def check_bert_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_bert_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -124,5 +206,13 @@ def test_bert():
spawn(check_bert, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bert_3d():
+ spawn(check_bert_3d, 8)
+
+
if __name__ == "__main__":
test_bert()
+ test_bert_3d()
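
The new `*_3d` tests spawn 8 processes because their configs use `tp_size = 2` and `pp_size = 2`: one model replica spans 2 * 2 = 4 ranks, so 8 ranks yield a data-parallel size of 2 and exercise tensor, pipeline, and data parallelism (plus ZeRO) together. A quick check of the arithmetic:

```python
world_size, tp_size, pp_size = 8, 2, 2
dp_size = world_size // (tp_size * pp_size)    # 2 data-parallel replicas
assert dp_size == 2
```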
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index ed0d1d8e401d..c9ee690c86dc 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -3,16 +3,19 @@
import colossalai
from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
@@ -34,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ # unwrap model
+ bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
+ sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
+
+ row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
+ col_layer_for_check = ['h[0].self_attention.dense']
+
+ # Save gradient tensors for comparison between the original model and the sharded model.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-5
+ else:
+ atol, rtol = 5e-3, 5e-3
+ row_layer_grads = get_grad_tensors_for_check(bloom,
+ sharded_bloom,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+ col_layer_grads = get_grad_tensors_for_check(bloom,
+ sharded_bloom,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+    # run the optimizer step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@@ -45,28 +85,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- # unwrap model
- if org_model.__class__.__name__ == 'BloomModel':
- bloom = org_model
- sharded_bloom = sharded_model.unwrap()
- else:
- bloom = org_model.transformer
- sharded_bloom = sharded_model.unwrap().transformer
-
- # check grad
- row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
- col_layer_for_check = ['h[0].self_attention.dense']
- if stage_manager is None or stage_manager.is_first_stage():
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-6, 1e-5
- else:
- atol, rtol = 5e-3, 5e-3
- check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
- check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
-
- # check weights after optimizer.step()
- org_optimizer.step()
- sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@@ -74,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
torch.cuda.empty_cache()
@@ -97,18 +118,72 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
- 'precision': 'fp32',
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 1,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
def run_bloom_test(test_config):
- # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'zero_stage': 1,
+ 'initial_scale': 1,
+ },
+])
+def run_bloom_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -118,6 +193,12 @@ def check_bloom(rank, world_size, port):
run_bloom_test()
+def check_bloom_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_bloom_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -125,5 +206,13 @@ def test_bloom():
spawn(check_bloom, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bloom_3d():
+ spawn(check_bloom_3d, 8)
+
+
if __name__ == "__main__":
test_bloom()
+ test_bloom_3d()
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py
similarity index 52%
rename from tests/test_shardformer/test_model/test_shard_chatglm.py
rename to tests/test_shardformer/test_model/test_shard_chatglm2.py
index bb77759048b3..48f651c727f4 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -4,16 +4,19 @@
import colossalai
from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
@@ -35,6 +38,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ # unwrap model
+ chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
+ shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
+
+ row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
+ col_layer_for_check = ['encoder.layers[0].self_attention.dense']
+
+ # Save gradient tensors for comparison between the original model and the sharded model.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ row_layer_grads = get_grad_tensors_for_check(chatglm_model,
+ shard_chatglm_model,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+
+ col_layer_grads = get_grad_tensors_for_check(chatglm_model,
+ shard_chatglm_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+    # run the optimizer step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@@ -47,43 +88,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- # unwrap model
- if org_model.__class__.__name__ == 'ChatGLMModel':
- chatglm_model = org_model
- shard_chatglm_model = sharded_model.unwrap()
- else:
- chatglm_model = org_model.transformer
- shard_chatglm_model = sharded_model.unwrap().transformer
-
- # check grad
- row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
- col_layer_for_check = ['encoder.layers[0].self_attention.dense']
- if stage_manager is None or stage_manager.is_first_stage():
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-6, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
- check_grad(chatglm_model,
- shard_chatglm_model,
- row_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=0,
- verbose=False)
-
- check_grad(chatglm_model,
- shard_chatglm_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False)
-
- # check weights after optimizer.step()
- org_optimizer.step()
- sharded_optimizer.step()
+ # check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@@ -98,6 +103,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -121,12 +130,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
- 'precision': 'fp32',
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
def run_chatglm_test(test_config):
- # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+ clear_layout_converter()
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'zero_stage': 1,
+ 'initial_scale': 1,
+ },
+])
+def run_chatglm_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@@ -142,6 +194,12 @@ def check_chatglm(rank, world_size, port):
run_chatglm_test()
+def check_chatglm_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_chatglm_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -149,5 +207,13 @@ def test_chatglm():
spawn(check_chatglm, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_chatglm_3d():
+ spawn(check_chatglm_3d, 8)
+
+
if __name__ == "__main__":
test_chatglm()
+ test_chatglm_3d()
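
Across all of the shardformer tests in this patch the same refactor recurs: gradients are snapshotted with `get_grad_tensors_for_check` while they are still intact, both optimizers step, and the saved tensors are verified at the end with `check_all_grad_tensors`. Capturing before `optimizer.step()` matters because the step (and any subsequent zero-grad) can scale or discard `.grad`. A minimal sketch of the two helpers, assuming they live in `tests/test_shardformer/test_model/_utils.py`; the bodies below are assumptions, and the gather of sharded gradients over `tp_group` along `dim` is elided:

```python
import torch

def get_grad_tensors_for_check(org_model, sharded_model, layer_names, tp_group,
                               atol=1e-5, rtol=1e-3, dim=0, verbose=False):
    # Snapshot matching gradients *before* optimizer.step() (hypothetical body).
    grads = {}
    for name in layer_names:
        org_grad = eval(f"org_model.{name}").weight.grad.clone()
        shard_grad = eval(f"sharded_model.{name}").weight.grad.clone()
        # NOTE: the real helper presumably all-gathers shard_grad across
        # tp_group along `dim` before storing it; that collective is omitted here.
        grads[name] = dict(org=org_grad, shard=shard_grad, atol=atol, rtol=rtol)
    return grads

def check_all_grad_tensors(grads_to_check):
    # Deferred comparison, run after both optimizers have stepped.
    for name, g in grads_to_check.items():
        assert torch.allclose(g["org"], g["shard"], atol=g["atol"], rtol=g["rtol"]), \
            f"gradient mismatch at {name}"
```
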
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 1a81b3360655..c4cc3812dbfd 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -3,18 +3,20 @@
from torch import distributed as dist
import colossalai
-from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
@@ -36,6 +38,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ # unwrap model
+ gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
+ sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
+
+ col_layer_for_check = ['h[0].mlp.c_fc']
+ row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
+
+ # Save gradient tensors for comparison between the original model and the sharded model.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ col_layer_grads = get_grad_tensors_for_check(gpt2,
+ sharded_gpt2,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ row_layer_grads = get_grad_tensors_for_check(gpt2,
+ sharded_gpt2,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@@ -48,32 +87,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- def unwrap(module):
- if isinstance(module, HybridParallelModule):
- module = module.unwrap()
- if module.__class__.__name__ == 'GPT2Model':
- return module
- return module.transformer
-
- # unwrap model
- gpt2 = unwrap(org_model)
- sharded_gpt2 = unwrap(sharded_model)
-
- col_layer_for_check = ['h[0].mlp.c_fc']
- row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
-
- # check grad
- if stage_manager is None or stage_manager.is_first_stage():
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-4, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
- check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
- check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-
- # check weights after optimizer.step()
- org_optimizer.step()
- sharded_optimizer.step()
+ # check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
@@ -81,6 +95,10 @@ def unwrap(module):
atol, rtol = 5e-3, 5e-3
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -106,12 +124,72 @@ def unwrap(module):
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+}, {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp32',
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 1,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
- # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ torch.cuda.empty_cache()
+
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'zero_stage': 1,
+ 'initial_scale': 1,
+ },
+])
+@clear_cache_before_run()
+def run_gpt2_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@@ -127,10 +205,12 @@ def check_gpt2(rank, world_size, port):
run_gpt2_test()
-# TODO(ver217): fix this
+def check_gpt2_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_gpt2_3d_test()
-@pytest.mark.skip("this will stuck in CI")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -138,5 +218,13 @@ def test_gpt2():
spawn(check_gpt2, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gpt2_3d():
+ spawn(check_gpt2_3d, 8)
+
+
if __name__ == "__main__":
test_gpt2()
+ test_gpt2_3d()
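
The local `unwrap` closure deleted above (and the class-name branching removed from the other tests) is replaced by one shared `unwrap_model` import. A hedged sketch consistent with call sites like `unwrap_model(org_model, 'GPT2Model', 'transformer')`; the actual `_utils.py` implementation may differ:

```python
def unwrap_model(module, base_model_class_name=None, base_model_attribute_name=None):
    # Peel off the booster's HybridParallelModule wrapper if present
    # (duck-typed here; the real helper may check the class explicitly).
    if hasattr(module, "unwrap"):
        module = module.unwrap()
    # T5-style call sites pass no names and take the module as-is.
    if base_model_class_name is None:
        return module
    # Already the backbone (e.g. GPT2Model): return it directly.
    if module.__class__.__name__ == base_model_class_name:
        return module
    # Otherwise descend from a task head (e.g. GPT2LMHeadModel) into the
    # backbone attribute: 'transformer', 'model', 'vit', ...
    return getattr(module, base_model_attribute_name)
```
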
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 30ebdfbe5cd9..a60150e3cd72 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -6,16 +6,19 @@
import colossalai
from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -39,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ # unwrap model
+ llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
+ shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
+
+ row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
+ col_layer_for_check = ['layers[0].self_attn.o_proj']
+
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-4
+ else:
+ atol, rtol = 5e-3, 5e-3
+ row_layer_grads = get_grad_tensors_for_check(llama_model,
+ shard_llama_model,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+ col_layer_grads = get_grad_tensors_for_check(llama_model,
+ shard_llama_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@@ -51,42 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- # unwrap model
- if org_model.__class__.__name__ == 'LlamaModel':
- llama_model = org_model
- shard_llama_model = sharded_model.unwrap()
- else:
- llama_model = org_model.model
- shard_llama_model = sharded_model.unwrap().model
-
- # check grad
- row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
- col_layer_for_check = ['layers[0].self_attn.o_proj']
- if stage_manager is None or stage_manager.is_first_stage():
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-6, 1e-4
- else:
- atol, rtol = 5e-3, 5e-3
- check_grad(llama_model,
- shard_llama_model,
- row_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=0,
- verbose=False)
- check_grad(llama_model,
- shard_llama_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False)
-
- # check weights after optimizer.step()
- org_optimizer.step()
- sharded_optimizer.step()
+ # check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@@ -101,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
torch.cuda.empty_cache()
@@ -128,19 +136,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
+ 'enable_all_optimization': False,
'use_lazy_init': False,
- 'precision': 'fp32',
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 1,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
def run_llama_test(test_config):
- # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'zero_stage': 1,
+ 'initial_scale': 1,
+ },
+])
+def run_llama_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -150,6 +213,12 @@ def check_llama(rank, world_size, port):
run_llama_test()
+def check_llama_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_llama_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -157,5 +226,13 @@ def test_llama():
spawn(check_llama, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama_3d():
+ spawn(check_llama_3d, 8)
+
+
if __name__ == "__main__":
test_llama()
+ test_llama_3d()
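
The new `largedist` entry points spawn 8 ranks for a reason: with `tp_size=2` and `pp_size=2`, the hybrid plugin is left with a data-parallel axis of size 2, so these configs exercise tensor, pipeline, and data parallelism (optionally with ZeRO stage 1) at once. The arithmetic, assuming the plugin derives `dp_size` from the leftover ranks:

```python
world_size = 8           # ranks spawned by test_llama_3d / test_gpt2_3d / ...
tp_size, pp_size = 2, 2  # from the *_3d test configs above
dp_size = world_size // (tp_size * pp_size)
assert dp_size == 2      # a non-trivial DP axis -> genuinely 3D parallelism
```
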
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 8d1154d82638..3e74859ad1a8 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -6,16 +6,19 @@
import colossalai
from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -39,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ # unwrap model
+ opt_model = unwrap_model(org_model, 'OPTModel', 'model')
+ shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
+
+    row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
+ col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
+
+ # Save gradient tensors for comparison between the original model and the sharded model.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-3
+ else:
+ atol, rtol = 4e-2, 4e-2
+ row_layer_grads = get_grad_tensors_for_check(opt_model,
+ shard_opt_model,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+ col_layer_grads = get_grad_tensors_for_check(opt_model,
+ shard_opt_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@@ -50,42 +90,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- # unwrap model
- if org_model.__class__.__name__ == 'OPTModel':
- opt_model = org_model
- shard_opt_model = sharded_model.unwrap()
- else:
- opt_model = org_model.model
- shard_opt_model = sharded_model.unwrap().model
-
- # check grad
- row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
- col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
- if stage_manager is None or stage_manager.is_first_stage():
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-6, 1e-3
- else:
- atol, rtol = 3e-2, 3e-2
- check_grad(opt_model,
- shard_opt_model,
- row_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=0,
- verbose=False)
- check_grad(opt_model,
- shard_opt_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False)
-
- # check weights after optimizer.step()
- org_optimizer.step()
- sharded_optimizer.step()
+ # check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
@@ -100,6 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -123,12 +132,62 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
- 'precision': 'fp32',
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 1,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
def run_opt_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
- # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+ clear_layout_converter()
+ torch.cuda.empty_cache()
+
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'zero_stage': 1,
+ 'initial_scale': 1,
+ },
+])
+def run_opt_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@@ -144,6 +203,12 @@ def check_OPTModel(rank, world_size, port):
run_opt_test()
+def check_opt_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_opt_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -151,5 +216,13 @@ def test_OPTModel():
spawn(check_OPTModel, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_opt_3d():
+ spawn(check_opt_3d, 8)
+
+
if __name__ == '__main__':
test_OPTModel()
+ test_opt_3d()
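
Every grad snapshot in this patch is gated on `(stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0`. The reading below is an inference from the diff, not something it states: the checked embedding and attention layers live on the first pipeline stage, and under ZeRO stage 1/2 gradients are sharded into the optimizer, so `p.grad` is no longer a full tensor comparable against the unsharded model.

```python
def should_snapshot_grads(stage_manager, zero_stage):
    # Hypothetical distillation of the guard repeated in each test above.
    on_first_stage = stage_manager is None or stage_manager.is_first_stage()
    return on_first_stage and zero_stage == 0
```
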
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 066f7ee815b4..768cae0a6734 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -1,5 +1,6 @@
import pytest
import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers
@@ -9,11 +10,13 @@
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
@@ -35,6 +38,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
+ # unwrap model
+ t5 = unwrap_model(org_model)
+ sharded_t5 = unwrap_model(sharded_model)
+
+ row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
+
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+ row_layer_grads = get_grad_tensors_for_check(t5,
+ sharded_t5,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@@ -47,30 +76,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
- # unwrap model
- t5 = org_model
- sharded_t5 = sharded_model.unwrap()
-
- row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
-
- # check weights and gradients
+ # check weights
if test_config['precision'] == 'fp32':
- atol, rtol = 1e-5, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
- if stage_manager is None or stage_manager.is_first_stage():
- check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
-
- # check weights after optimizer.step()
- org_optimizer.step()
- sharded_optimizer.step()
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-4, 1e-3
+ atol, rtol = 5e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
torch.cuda.empty_cache()
@@ -99,17 +115,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
+ 'enable_all_optimization': False,
'use_lazy_init': False,
- 'precision': 'fp32',
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'zero_stage': 1,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
@clear_cache_before_run()
def run_t5_test(test_config):
- # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
- # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
- # TODO(baizhou): add test_config for flash attention & jit operator after supporting
-
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@@ -125,12 +160,49 @@ def run_t5_test(test_config):
torch.cuda.empty_cache()
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'zero_stage': 1,
+ 'initial_scale': 1,
+ },
+])
+def run_t5_3d_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ torch.cuda.empty_cache()
+
+
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_test()
+def check_t5_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_t5_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -138,5 +210,13 @@ def test_t5():
spawn(check_t5, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_t5_3d():
+ spawn(check_t5_3d, 8)
+
+
if __name__ == "__main__":
test_t5()
+ test_t5_3d()
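
All of these config tables are driven by `colossalai.testing.parameterize`, which turns each dict into one invocation of the decorated function. A rough sketch of how such a decorator can work, so the tables read as "run the test once per entry" (the real decorator may add naming, filtering, or error reporting):

```python
def parameterize(arg_name, values):
    # Call the wrapped test once per config, binding each dict to arg_name.
    def decorator(fn):
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize('demo_config', [{'tp_size': 2, 'pp_size': 1}, {'tp_size': 1, 'pp_size': 2}])
def run_demo(demo_config):
    print(demo_config)

run_demo()  # prints both configs in turn
```
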
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 18df8ef555f2..15db63bfd9da 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -9,11 +9,13 @@
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
- check_grad,
+ check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
+ get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
)
@@ -35,54 +37,56 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
- # check last hidden state & loss
- if stage_manager is None or stage_manager.is_last_stage():
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-5, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
-
- if org_model.__class__.__name__ == 'ViTModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
- check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
# unwrap model
- if org_model.__class__.__name__ == 'ViTModel':
- vit_model = org_model
- shard_vit_model = sharded_model.unwrap()
- else:
- vit_model = org_model.vit
- shard_vit_model = sharded_model.unwrap().vit
+ vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
+ shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
# check grad
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
- if stage_manager is None or stage_manager.is_first_stage():
+
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
+ if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
- check_grad(vit_model,
- shard_vit_model,
- row_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=0,
- verbose=False)
- check_grad(vit_model,
- shard_vit_model,
- col_layer_for_check,
- tp_group,
- atol=atol,
- rtol=rtol,
- dim=1,
- verbose=False)
-
- # check weights after optimizer.step()
+ row_layer_grads = get_grad_tensors_for_check(vit_model,
+ shard_vit_model,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+ col_layer_grads = get_grad_tensors_for_check(vit_model,
+ shard_vit_model,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
+
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+
+ if org_model.__class__.__name__ == 'ViTModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+ # check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
@@ -97,9 +101,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
torch.cuda.empty_cache()
+# TODO: loss becomes inf when num_microbatches == 2
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
@@ -120,15 +128,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
- 'precision': 'fp32',
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32'
+}, {
+ 'tp_size': 2,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'zero_stage': 2,
+ 'precision': 'fp16',
+ 'initial_scale': 1
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'zero_stage': 1,
+ 'precision': 'fp16',
+ 'initial_scale': 1
}])
def run_vit_test(test_config):
- # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
- # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models
+    # TODO: fix bug when setting lazy_init for Conv2D Layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
-
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -137,12 +166,48 @@ def run_vit_test(test_config):
torch.cuda.empty_cache()
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+])
+def run_vit_3d_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ torch.cuda.empty_cache()
+
+
def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_test()
+def check_vit_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_vit_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -150,5 +215,13 @@ def test_vit():
spawn(check_vit, 4)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_vit_3d():
+ spawn(check_vit_3d, 8)
+
+
if __name__ == "__main__":
test_vit()
+ test_vit_3d()
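
Several suites now call `Randomizer.reset_index()` between configs. A plausible reading, sketched with a hypothetical stand-in (the real class is `colossalai.shardformer.layer.utils.Randomizer`): seed offsets are handed out from a process-wide counter as parallel layers are built, so resetting the counter keeps dropout RNG aligned across successive parameterized runs.

```python
class Randomizer:
    """Hypothetical stand-in for colossalai.shardformer.layer.utils.Randomizer."""
    _INDEX = 0  # process-wide counter that offsets per-layer RNG seeds

    @classmethod
    def index(cls):
        idx = cls._INDEX
        cls._INDEX += 1
        return idx

    @classmethod
    def reset_index(cls):
        # Without this, the next test config would build its layers from a
        # shifted counter and draw different dropout masks than the previous run.
        cls._INDEX = 0
```
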
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 9b38ae07b1d6..d0c04c98f80a 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -3,6 +3,8 @@
import colossalai
from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
@@ -11,55 +13,205 @@
spawn,
)
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_all_grad_tensors,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ get_grad_tensors_for_check,
+ run_forward_backward_with_hybrid_plugin,
+)
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# check forward
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
- assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
-
- # do backward
- org_loss.backward()
- shard_loss.backward()
-
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
+
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
     # unwrap the model
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
whisper = org_model.model
- sharded_whisper = sharded_model.model
+ sharded_whisper = sharded_model.unwrap().model
else:
whisper = org_model
- sharded_whisper = sharded_model
+ sharded_whisper = sharded_model.unwrap()
# check grad
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
else:
- col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
- row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
- check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
- check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
+ col_layer_for_check = [
+ 'encoder.layers[0].self_attn.q_proj',
+ # 'decoder.layers[0].self_attn.q_proj'
+ ]
+ row_layer_for_check = [
+ 'encoder.layers[0].self_attn.out_proj',
+ #'decoder.layers[0].self_attn.out_proj'
+ ]
+
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 2e-4, 2e-4
+ else:
+ atol, rtol = 5e-3, 5e-3
+
+ if stage_manager is None or stage_manager.is_first_stage():
+ row_layer_grads = get_grad_tensors_for_check(whisper,
+ sharded_whisper,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1)
+ col_layer_grads = get_grad_tensors_for_check(whisper,
+ sharded_whisper,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0)
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 2e-4, 2e-4
+ else:
+ atol, rtol = 5e-3, 5e-3
+
+ if org_model.__class__.__name__ == 'WhisperModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+ # check weights
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-3, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(whisper,
+ sharded_whisper,
+ row_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False)
+ check_weight(whisper,
+ sharded_whisper,
+ col_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=0,
+ verbose=False)
+
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
+ torch.cuda.empty_cache()
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
+# TODO: fix WhisperForConditionalGeneration when the jit fused operator is enabled
+# TODO(jianghai): fix fp16
+@parameterize(
+ 'test_config',
+ [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ },
+ {
+ 'tp_size': 1,
+ 'pp_size': 4,
+ 'num_microbatches': 4,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ },
+        # whisper does not support fp16 for now.
+ ])
+def run_whisper_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn,
- enable_fused_normalization=enable_fused_normalization,
- enable_tensor_parallelism=enable_tensor_parallelism,
- enable_flash_attention=enable_flash_attention,
- enable_jit_fused=enable_jit_fused)
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+ if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
+ continue
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+ {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
+ 'initial_scale': 1,
+ },
+])
+def run_whisper_3d_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
torch.cuda.empty_cache()
@@ -69,12 +221,26 @@ def check_whisper(rank, world_size, port):
run_whisper_test()
+def check_whisper_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ run_whisper_3d_test()
+
+
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper():
- spawn(check_whisper, 2)
+ spawn(check_whisper, 4)
+
+
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_whisper_3d():
+ spawn(check_whisper_3d, 8)
if __name__ == "__main__":
test_whisper()
+ test_whisper_3d()
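
With Whisper migrated, `check_loss` is now the one assertion shared by every rewritten suite in this patch. Its call sites pin down the helper's shape; a sketch consistent with them (the actual `_utils.py` version may also be pipeline-stage aware):

```python
import torch

def check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3):
    # Scalar losses from the original and sharded runs must agree within tolerance.
    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
        f"sharded loss {sharded_loss.item():.6f} diverged from original {org_loss.item():.6f}"
```
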
diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index 2930552cc4e7..b7764c2f4371 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -40,7 +40,6 @@ def forward_inplace(x, weight):
return out
-@pytest.mark.gpu
@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py
index 335be61359ed..9c3a7e2161d2 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py
+++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
-import colossalai.nn as col_nn
+import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py
index 175d9ef6ceb9..03b2e4f2a9b2 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py
+++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
-import colossalai.nn as col_nn
+import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
index 33cb3a65d184..cafffd0a6202 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
+++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
-import colossalai.nn as col_nn
+import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py
index 73ac2dd5fe18..9b43be9e8cc5 100644
--- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py
+++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
-import colossalai.nn as col_nn
+import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
diff --git a/version.txt b/version.txt
index 9e11b32fcaa9..d15723fbe8de 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.1
+0.3.2