diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml
index 650689498fda..a336526897e2 100644
--- a/.github/workflows/run_chatgpt_examples.yml
+++ b/.github/workflows/run_chatgpt_examples.yml
@@ -28,9 +28,8 @@ jobs:
- name: Checkout ColossalAI
uses: actions/checkout@v2
- - name: Install ColossalAI and ChatGPT
+ - name: Install ChatGPT
run: |
- pip install -e .
cd applications/Chat
pip install -v .
pip install -r examples/requirements.txt
diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml
index 47c80fc9a9fe..ec5c8ffa319f 100644
--- a/.github/workflows/run_chatgpt_unit_tests.yml
+++ b/.github/workflows/run_chatgpt_unit_tests.yml
@@ -30,9 +30,8 @@ jobs:
- name: Checkout ColossalAI
uses: actions/checkout@v2
- - name: Install ColossalAI and ChatGPT
+ - name: Install ChatGPT
run: |
- pip install -e .
cd applications/Chat
pip install -v .
pip install -r requirements-test.txt
diff --git a/README.md b/README.md
index 44e4f97f1f4e..0ddcdab741a4 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
## Latest News
+* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
@@ -50,7 +51,7 @@
Parallel Training Demo
- - LLaMA
+ - LLaMA 1/2
- GPT-3
- GPT-2
- BERT
@@ -217,8 +218,16 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
(back to top)
## Parallel Training Demo
+### LLaMA2
+
+- 70 billion parameter LLaMA2 model training accelerated by 195%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
-### LLaMA
+### LLaMA1
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 636b4e6772cb..2959d3fac81c 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -19,7 +19,7 @@
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
-
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
@@ -71,6 +71,42 @@ def _preprocess(sources: Sequence[str],
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
+def _preprocess_chatglm(sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Preprocess the data by tokenizing.
+ The attention mask is returned as None; ChatGLM will compute it from the input ids.
+ """
+
+ labels = []
+ input_ids = []
+ for source, target in zip(sources, targets):
+ source_id = tokenizer.encode(text=source, add_special_tokens=False)
+ target_id = tokenizer.encode(text=target, add_special_tokens=False)
+ input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)
+ # truncate
+ sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
+ truncate_length = max(0, len(input_id) - max_length)
+ input_id = input_id[truncate_length: ]
+ if truncate_length == len(source_id) + 1:
+ input_id = sp_token_list + input_id[1: ]
+ elif truncate_length > len(source_id) + 1:
+ input_id = sp_token_list + input_id[2: ]
+
+ context_length = input_id.index(tokenizer.bos_token_id)
+ mask_position = context_length - 1
+ label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
+
+ pad_len = max_length - len(input_id)
+ input_id = input_id + [tokenizer.pad_token_id] * pad_len
+ input_ids.append(input_id)
+ labels.append(label + [IGNORE_INDEX] * pad_len)
+ return torch.tensor(input_ids), torch.tensor(labels), None
+
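+# Illustrative layout (not part of the original patch) of what _preprocess_chatglm
+# returns for one prompt/completion pair, where s1..sn are the prompt tokens and
+# t1..tm the completion tokens:
+#
+#   input_id: [s1, ..., sn, gMASK, BOS, t1, ..., tm, PAD, ...]
+#   label:    [IGNORE_INDEX, ..., IGNORE_INDEX, BOS, t1, ..., tm, IGNORE_INDEX, ...]
+#
+# Every position up to and including gMASK is ignored by the loss, and the attention
+# mask is None because ChatGLM rebuilds it from the input ids
+# (see get_masks in modeling_chatglm.py).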
+
class SFTDataset(Dataset):
"""
Dataset for sft model
@@ -94,18 +130,25 @@ def __init__(self,
data["completion"] + tokenizer.eos_token
for data in tqdm(dataset, disable=not is_rank_0())
]
-
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ else:
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ if self.attention_mask is not None:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx],
+ attention_mask=self.attention_mask[idx])
+ else:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx])
class SupervisedDataset(Dataset):
@@ -137,14 +180,22 @@ def __init__(self,
]
logger.info("Tokenizing inputs... This may take some time...")
- self.input_ids, self.labels, self.attention_mask = \
- _preprocess(sources, targets, tokenizer, max_length)
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess_chatglm(sources, targets, tokenizer, max_length)
+ else:
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
- return dict(input_ids=self.input_ids[idx],
- labels=self.labels[idx],
- attention_mask=self.attention_mask[idx])
+ if self.attention_mask is not None:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx],
+ attention_mask=self.attention_mask[idx])
+ else:
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx])
diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py
new file mode 100644
index 000000000000..373f19553fdc
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/__init__.py
@@ -0,0 +1,3 @@
+from .chatglm_actor import ChatGLMActor
+
+__all__ = ['ChatGLMActor']
\ No newline at end of file
diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py
new file mode 100644
index 000000000000..c35d994e9319
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import torch
+from .configuration_chatglm import ChatGLMConfig
+from .modeling_chatglm import ChatGLMForConditionalGeneration
+
+from ..base import Actor
+
+
+class ChatGLMActor(Actor):
+ """
+ ChatGLM Actor model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (ChatGLMConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+
+ LoRA is not supported for now.
+ """
+
+ def __init__(self,
+ pretrained: str = None,
+ config: Optional[ChatGLMConfig] = None,
+ checkpoint: bool = False) -> None:
+ if pretrained is not None:
+ model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
+ elif config is not None:
+ model = ChatGLMForConditionalGeneration(config)
+ else:
+ model = ChatGLMForConditionalGeneration(ChatGLMConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank=0, lora_train_bias='none')
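+
+# Minimal usage sketch (illustrative, not part of the original patch):
+#
+#   from coati.models.chatglm import ChatGLMActor
+#   actor = ChatGLMActor(pretrained="THUDM/chatglm-6b", checkpoint=True)
+#
+# LoRA is deliberately disabled for ChatGLM, so the actor is always constructed with
+# lora_rank=0 and lora_train_bias='none'.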
diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
new file mode 100644
index 000000000000..f7717f7e68b6
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
@@ -0,0 +1,446 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
+"""
+"""Tokenization classes for ChatGLM."""
+from typing import List, Optional, Union
+import os
+
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging, PaddingStrategy
+from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
+from typing import Dict
+import sentencepiece as spm
+import numpy as np
+
+logger = logging.get_logger(__name__)
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "THUDM/chatglm-6b": 2048,
+}
+
+
+class TextTokenizer:
+ def __init__(self, model_path):
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.Load(model_path)
+ self.num_tokens = self.sp.vocab_size()
+
+ def encode(self, text):
+ return self.sp.EncodeAsIds(text)
+
+ def decode(self, ids: List[int]):
+ return self.sp.DecodeIds(ids)
+
+ def tokenize(self, text):
+ return self.sp.EncodeAsPieces(text)
+
+ def convert_tokens_to_string(self, tokens):
+ return self.sp.DecodePieces(tokens)
+
+ def convert_tokens_to_ids(self, tokens):
+ return [self.sp.PieceToId(token) for token in tokens]
+
+ def convert_token_to_id(self, token):
+ return self.sp.PieceToId(token)
+
+ def convert_id_to_token(self, idx):
+ return self.sp.IdToPiece(idx)
+
+ def __len__(self):
+ return self.num_tokens
+
+
+class SPTokenizer:
+ def __init__(
+ self,
+ vocab_file,
+ num_image_tokens=20000,
+ max_blank_length=80,
+ byte_fallback=True,
+ ):
+ assert vocab_file is not None
+ self.vocab_file = vocab_file
+ self.num_image_tokens = num_image_tokens
+ self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
+ self.max_blank_length = max_blank_length
+ self.byte_fallback = byte_fallback
+ self.text_tokenizer = TextTokenizer(vocab_file)
+
+ def _get_text_tokenizer(self):
+ return self.text_tokenizer
+
+ @staticmethod
+ def get_blank_token(length: int):
+ assert length >= 2
+ return f"<|blank_{length}|>"
+
+ @staticmethod
+ def get_tab_token():
+ return f"<|tab|>"
+
+ @property
+ def num_text_tokens(self):
+ return self.text_tokenizer.num_tokens
+
+ @property
+ def num_tokens(self):
+ return self.num_image_tokens + self.num_text_tokens
+
+ @staticmethod
+ def _encode_whitespaces(text: str, max_len: int = 80):
+ text = text.replace("\t", SPTokenizer.get_tab_token())
+ for i in range(max_len, 1, -1):
+ text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
+ return text
+
+ def _preprocess(self, text: str, linebreak=True, whitespaces=True):
+ if linebreak:
+ text = text.replace("\n", "<n>")
+ if whitespaces:
+ text = self._encode_whitespaces(text, max_len=self.max_blank_length)
+ return text
+
+ def encode(
+ self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
+ ) -> List[int]:
+ """
+ @param text: Text to encode.
+ @param linebreak: Whether to encode newline (\n) in text.
+ @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
+ @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
+ @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
+ """
+ text = self._preprocess(text, linebreak, whitespaces)
+ if not add_dummy_prefix:
+ text = "" + text
+ tmp = self._get_text_tokenizer().encode(text)
+ tokens = [x + self.num_image_tokens for x in tmp]
+ return tokens if add_dummy_prefix else tokens[2:]
+
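+ # Illustrative example of the pre-processing above (not part of the original patch):
+ # encode() first maps "def f():\n\treturn  1" to "def f():<n><|tab|>return<|blank_2|>1",
+ # i.e. newlines become "<n>", tabs become "<|tab|>" and runs of two or more spaces
+ # become "<|blank_{length}|>", before the string is handed to SentencePiece.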
+ def postprocess(self, text):
+ text = text.replace("<n>", "\n")
+ text = text.replace(SPTokenizer.get_tab_token(), "\t")
+ for i in range(2, self.max_blank_length + 1):
+ text = text.replace(self.get_blank_token(i), " " * i)
+ return text
+
+ def decode(self, text_ids: List[int]) -> str:
+ ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+ ids = [_id for _id in ids if _id >= 0]
+ text = self._get_text_tokenizer().decode(ids)
+ text = self.postprocess(text)
+ return text
+
+ def decode_tokens(self, tokens: List[str]) -> str:
+ text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+ text = self.postprocess(text)
+ return text
+
+ def tokenize(
+ self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
+ ) -> List[str]:
+ """
+ @param text: Text to encode.
+ @param linebreak: Whether to encode newline (\n) in text.
+ @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
+ @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
+ @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
+ """
+ text = self._preprocess(text, linebreak, whitespaces)
+ if not add_dummy_prefix:
+ text = "" + text
+ tokens = self._get_text_tokenizer().tokenize(text)
+ return tokens if add_dummy_prefix else tokens[2:]
+
+ def __getitem__(self, x: Union[int, str]):
+ if isinstance(x, int):
+ if x < self.num_image_tokens:
+ return "".format(x)
+ else:
+ return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
+ elif isinstance(x, str):
+ if x.startswith("<image_") and x[7:-1].isdigit():
+ return int(x[7:-1])
+ else:
+ return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
+ else:
+ raise ValueError("The key should be str or int.")
+
+
+class ChatGLMTokenizer(PreTrainedTokenizer):
+ """
+ Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ """
+
+ vocab_files_names = {"vocab_file": "ice_text.model"}
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=False,
+ remove_space=False,
+ bos_token='<sop>',
+ eos_token='<eop>',
+ end_token='</s>',
+ mask_token='[MASK]',
+ gmask_token='[gMASK]',
+ padding_side="left",
+ pad_token="<pad>",
+ unk_token="<unk>",
+ num_image_tokens=20000,
+ **kwargs
+ ) -> None:
+ super().__init__(
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ padding_side=padding_side,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ end_token=end_token,
+ mask_token=mask_token,
+ gmask_token=gmask_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ num_image_tokens=num_image_tokens,
+ **kwargs
+ )
+
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.vocab_file = vocab_file
+
+ self.bos_token = bos_token
+ self.eos_token = eos_token
+ self.end_token = end_token
+ self.mask_token = mask_token
+ self.gmask_token = gmask_token
+
+ self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens)
+
+ """ Initialisation """
+
+ @property
+ def gmask_token_id(self) -> Optional[int]:
+ if self.gmask_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.gmask_token)
+
+ @property
+ def end_token_id(self) -> Optional[int]:
+ """
+ `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
+ set.
+ """
+ if self.end_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.end_token)
+
+ @property
+ def vocab_size(self):
+ """ Returns vocab size """
+ return self.sp_tokenizer.num_tokens
+
+ def get_vocab(self):
+ """ Returns vocab as a dict """
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def preprocess_text(self, inputs):
+ if self.remove_space:
+ outputs = " ".join(inputs.strip().split())
+ else:
+ outputs = inputs
+
+ if self.do_lower_case:
+ outputs = outputs.lower()
+
+ return outputs
+
+ def _tokenize(self, text, **kwargs):
+ """ Returns a tokenized string. """
+ text = self.preprocess_text(text)
+
+ seq = self.sp_tokenizer.tokenize(text)
+
+ return seq
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ return self.sp_tokenizer.decode_tokens(tokens)
+
+ def _decode(
+ self,
+ token_ids: Union[int, List[int]],
+ **kwargs
+ ) -> str:
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
+ if len(token_ids) == 0:
+ return ""
+ if self.pad_token_id in token_ids: # remove pad
+ token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
+ return super()._decode(token_ids, **kwargs)
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.sp_tokenizer[token]
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.sp_tokenizer[index]
+
+ def save_vocabulary(self, save_directory, filename_prefix=None):
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+ An optional prefix to add to the names of the saved files.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, self.vocab_files_names["vocab_file"]
+ )
+ else:
+ vocab_file = save_directory
+
+ with open(self.vocab_file, 'rb') as fin:
+ proto_str = fin.read()
+
+ with open(vocab_file, "wb") as writer:
+ writer.write(proto_str)
+
+ return (vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ gmask_id = self.sp_tokenizer[self.gmask_token]
+ eos_id = self.sp_tokenizer[self.eos_token]
+ token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
+ if token_ids_1 is not None:
+ token_ids_0 = token_ids_0 + token_ids_1
+ return token_ids_0
+
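+ # Illustrative result (not part of the original patch): for a single sequence the
+ # method returns token_ids_0 + [gMASK, BOS]; for a pair it returns
+ # token_ids_0 + [gMASK, BOS] + token_ids_1. _preprocess_chatglm in
+ # coati/dataset/sft_dataset.py relies on this layout when locating the BOS token.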
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ bos_token_id = self.sp_tokenizer[self.bos_token]
+ mask_token_id = self.sp_tokenizer[self.mask_token]
+ gmask_token_id = self.sp_tokenizer[self.gmask_token]
+ assert self.padding_side == "left"
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ seq_length = len(required_input)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if max_length is not None:
+ if "attention_mask" not in encoded_inputs:
+ if bos_token_id in required_input:
+ context_length = required_input.index(bos_token_id)
+ else:
+ context_length = seq_length
+ attention_mask = np.ones((1, seq_length, seq_length))
+ attention_mask = np.tril(attention_mask)
+ attention_mask[:, :, :context_length] = 1
+ attention_mask = np.bool_(attention_mask < 0.5)
+ encoded_inputs["attention_mask"] = attention_mask
+
+ if "position_ids" not in encoded_inputs:
+ if bos_token_id in required_input:
+ context_length = required_input.index(bos_token_id)
+ else:
+ context_length = seq_length
+ position_ids = np.arange(seq_length, dtype=np.int64)
+ mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+ if mask_token in required_input:
+ mask_position = required_input.index(mask_token)
+ position_ids[context_length:] = mask_position
+ block_position_ids = np.concatenate(
+ [np.zeros(context_length, dtype=np.int64),
+ np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+ encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+
+ if "attention_mask" in encoded_inputs:
+ encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
+ pad_width=[(0, 0), (difference, 0), (difference, 0)],
+ mode='constant', constant_values=True)
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ if "position_ids" in encoded_inputs:
+ encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
+ pad_width=[(0, 0), (difference, 0)])
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+ return encoded_inputs
\ No newline at end of file
diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
new file mode 100644
index 000000000000..d0e3f6cc63d7
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
@@ -0,0 +1,107 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/configuration_chatglm.py
+"""
+
+""" ChatGLM model configuration """
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class ChatGLMConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`~ChatGLMModel`].
+ It is used to instantiate an ChatGLM model according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used
+ to control the model outputs. Read the documentation from [`PretrainedConfig`]
+ for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 150528):
+ Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~ChatGLMModel`] or
+ [`~TFChatGLMModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ inner_hidden_size (`int`, *optional*, defaults to 16384):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with.
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+ layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values attentions (not used by all models).
+ Example:
+
+ ```python
+ >>> from configuration_chatglm import ChatGLMConfig
+ >>> from modeling_chatglm import ChatGLMModel
+
+ >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration
+ >>> configuration = ChatGLMConfig()
+
+ >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
+ >>> model = ChatGLMModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+"""
+ model_type = "chatglm"
+
+ def __init__(
+ self,
+ vocab_size=130528,
+ hidden_size=4096,
+ num_layers=28,
+ num_attention_heads=32,
+ layernorm_epsilon=1e-5,
+ use_cache=True,
+ bos_token_id=130004,
+ eos_token_id=130005,
+ mask_token_id=130000,
+ gmask_token_id=130001,
+ pad_token_id=3,
+ max_sequence_length=2048,
+ inner_hidden_size=16384,
+ position_encoding_2d=True,
+ quantization_bit=0,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs
+ ):
+ self.num_layers = num_layers
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.max_sequence_length = max_sequence_length
+ self.layernorm_epsilon = layernorm_epsilon
+ self.inner_hidden_size = inner_hidden_size
+ self.use_cache = use_cache
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.mask_token_id = mask_token_id
+ self.gmask_token_id = gmask_token_id
+ self.position_encoding_2d = position_encoding_2d
+ self.quantization_bit = quantization_bit
+ self.pre_seq_len = pre_seq_len
+ self.prefix_projection = prefix_projection
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs
+ )
\ No newline at end of file
diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
new file mode 100644
index 000000000000..77e7d0d8ea09
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
@@ -0,0 +1,1439 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py
+"""
+
+""" PyTorch ChatGLM model. """
+
+import math
+import copy
+import os
+import warnings
+import re
+import sys
+
+import torch
+import torch.utils.checkpoint
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn.utils import skip_init
+from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ BaseModelOutputWithPastAndCrossAttentions,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
+
+from .configuration_chatglm import ChatGLMConfig
+
+# flags required to enable jit fusion kernels
+
+if sys.platform != 'darwin':
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_override_can_fuse_on_cpu(True)
+ torch._C._jit_override_can_fuse_on_gpu(True)
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
+_CONFIG_FOR_DOC = "ChatGLM6BConfig"
+
+CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "THUDM/chatglm-6b",
+ # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
+]
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
+ scores.zero_()
+ scores[..., 5] = 5e4
+ return scores
+
+
+def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ assert (
+ pointer.shape == array.shape
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+class PrefixEncoder(torch.nn.Module):
+ """
+ The torch.nn model to encode the prefix
+ Input shape: (batch-size, prefix-length)
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.prefix_projection = config.prefix_projection
+ if self.prefix_projection:
+ # Use a two-layer MLP to encode the prefix
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+ self.trans = torch.nn.Sequential(
+ torch.nn.Linear(config.hidden_size, config.hidden_size),
+ torch.nn.Tanh(),
+ torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+ )
+ else:
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
+
+ def forward(self, prefix: torch.Tensor):
+ if self.prefix_projection:
+ prefix_tokens = self.embedding(prefix)
+ past_key_values = self.trans(prefix_tokens)
+ else:
+ past_key_values = self.embedding(prefix)
+ return past_key_values
+
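+# Shape sketch (illustrative, not part of the original patch): with the default
+# ChatGLMConfig (num_layers=28, hidden_size=4096) and, say, pre_seq_len=16, a prefix of
+# shape (batch, 16) is encoded to (batch, 16, 2 * 28 * 4096), which ChatGLMModel.get_prompt
+# later reshapes into per-layer key/value caches.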
+
+@torch.jit.script
+def gelu_impl(x):
+ """OpenAI's gelu implementation."""
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
+ (1.0 + 0.044715 * x * x)))
+
+
+def gelu(x):
+ return gelu_impl(x)
+
+
+class RotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ inv_freq = inv_freq.half()
+ self.learnable = learnable
+ if learnable:
+ self.inv_freq = torch.nn.Parameter(inv_freq)
+ self.max_seq_len_cached = None
+ else:
+ self.register_buffer('inv_freq', inv_freq)
+ self.max_seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+ self.precision = precision
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+ error_msgs):
+ pass
+
+ def forward(self, x, seq_dim=1, seq_len=None):
+ if seq_len is None:
+ seq_len = x.shape[seq_dim]
+ if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+ self.max_seq_len_cached = None if self.learnable else seq_len
+ t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ if self.precision == torch.bfloat16:
+ emb = emb.float()
+
+ # [sx, 1 (b * np), hn]
+ cos_cached = emb.cos()[:, None, :]
+ sin_cached = emb.sin()[:, None, :]
+ if self.precision == torch.bfloat16:
+ cos_cached = cos_cached.bfloat16()
+ sin_cached = sin_cached.bfloat16()
+ if self.learnable:
+ return cos_cached, sin_cached
+ self.cos_cached, self.sin_cached = cos_cached, sin_cached
+ return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+
+ def _apply(self, fn):
+ if self.cos_cached is not None:
+ self.cos_cached = fn(self.cos_cached)
+ if self.sin_cached is not None:
+ self.sin_cached = fn(self.sin_cached)
+ return super()._apply(fn)
+
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+
+@torch.jit.script
+def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
+ # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
+ F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+ return q, k
+
+
+def attention_fn(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ hidden_size_per_partition,
+ layer_id,
+ layer_past=None,
+ scaling_attention_score=True,
+ use_cache=False,
+):
+ if layer_past is not None:
+ past_key, past_value = layer_past[0], layer_past[1]
+ key_layer = torch.cat((past_key, key_layer), dim=0)
+ value_layer = torch.cat((past_value, value_layer), dim=0)
+
+ # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
+ seq_len, b, nh, hidden_size = key_layer.shape
+
+ if use_cache:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ query_key_layer_scaling_coeff = float(layer_id + 1)
+ if scaling_attention_score:
+ query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
+
+ # ===================================
+ # Raw attention scores. [b, np, s, s]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+
+ matmul_result = torch.zeros(
+ 1, 1, 1,
+ dtype=query_layer.dtype,
+ device=query_layer.device,
+ )
+
+ matmul_result = torch.baddbmm(
+ matmul_result,
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
+ beta=0.0,
+ alpha=1.0,
+ )
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ if self.scale_mask_softmax:
+ self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
+ else:
+ if not (attention_mask == 0).all():
+ # if auto-regressive, skip
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+ dtype = attention_scores.dtype
+ attention_scores = attention_scores.float()
+ attention_scores = attention_scores * query_key_layer_scaling_coeff
+
+ attention_probs = F.softmax(attention_scores, dim=-1)
+
+ attention_probs = attention_probs.type(dtype)
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, present, attention_probs)
+
+ return outputs
+
+
+def default_init(cls, *args, **kwargs):
+ return cls(*args, **kwargs)
+
+
+class SelfAttention(torch.nn.Module):
+ def __init__(self, hidden_size, num_attention_heads,
+ layer_id, hidden_size_per_attention_head=None, bias=True,
+ params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ super(SelfAttention, self).__init__()
+
+ self.layer_id = layer_id
+ self.hidden_size = hidden_size
+ self.hidden_size_per_partition = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_attention_heads_per_partition = num_attention_heads
+ self.position_encoding_2d = position_encoding_2d
+ self.rotary_emb = RotaryEmbedding(
+ self.hidden_size // (self.num_attention_heads * 2)
+ if position_encoding_2d
+ else self.hidden_size // self.num_attention_heads,
+ base=10000,
+ precision=torch.half,
+ learnable=False,
+ )
+
+ self.scale_mask_softmax = None
+
+ if hidden_size_per_attention_head is None:
+ self.hidden_size_per_attention_head = hidden_size // num_attention_heads
+ else:
+ self.hidden_size_per_attention_head = hidden_size_per_attention_head
+
+ self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+
+ # Strided linear layer.
+ self.query_key_value = init_method(
+ torch.nn.Linear,
+ hidden_size,
+ 3 * self.inner_hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ self.dense = init_method(
+ torch.nn.Linear,
+ self.inner_hidden_size,
+ hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ @staticmethod
+ def attention_mask_func(attention_scores, attention_mask):
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+ return attention_scores
+
+ def split_tensor_along_last_dim(self, tensor, num_partitions,
+ contiguous_split_chunks=False):
+ """Split a tensor along its last dimension.
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ attention_mask: [(1, 1), seq_len, seq_len]
+ """
+
+ # [seq_len, batch, 3 * hidden_size]
+ mixed_raw_layer = self.query_key_value(hidden_states)
+
+ # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
+ new_tensor_shape = mixed_raw_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)
+
+ # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
+ (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+ if self.position_encoding_2d:
+ q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
+ k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
+ cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
+ position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
+ position_ids[:, 1, :].transpose(0, 1).contiguous()
+ q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
+ q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
+ query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
+ key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
+ else:
+ position_ids = position_ids.transpose(0, 1)
+ cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
+ # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
+ query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)
+
+ # [seq_len, batch, hidden_size]
+ context_layer, present, attention_probs = attention_fn(
+ self=self,
+ query_layer=query_layer,
+ key_layer=key_layer,
+ value_layer=value_layer,
+ attention_mask=attention_mask,
+ hidden_size_per_partition=self.hidden_size_per_partition,
+ layer_id=layer_id,
+ layer_past=layer_past,
+ use_cache=use_cache
+ )
+
+ output = self.dense(context_layer)
+
+ outputs = (output, present)
+
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs # output, present, attention_probs
+
+
+class GEGLU(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.activation_fn = F.gelu
+
+ def forward(self, x):
+ # dim=-1 breaks in jit for pt<1.10
+ x1, x2 = x.chunk(2, dim=(x.ndim - 1))
+ return x1 * self.activation_fn(x2)
+
+
+class GLU(torch.nn.Module):
+ def __init__(self, hidden_size, inner_hidden_size=None,
+ layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
+ super(GLU, self).__init__()
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ self.layer_id = layer_id
+ self.activation_func = activation_func
+
+ # Project to 4h.
+ self.hidden_size = hidden_size
+ if inner_hidden_size is None:
+ inner_hidden_size = 4 * hidden_size
+ self.inner_hidden_size = inner_hidden_size
+ self.dense_h_to_4h = init_method(
+ torch.nn.Linear,
+ self.hidden_size,
+ self.inner_hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+ # Project back to h.
+ self.dense_4h_to_h = init_method(
+ torch.nn.Linear,
+ self.inner_hidden_size,
+ self.hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ def forward(self, hidden_states):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ """
+
+ # [seq_len, batch, inner_hidden_size]
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+
+ output = self.dense_4h_to_h(intermediate_parallel)
+
+ return output
+
+
+class GLMBlock(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_attention_heads,
+ layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=None,
+ hidden_size_per_attention_head=None,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=torch.float,
+ num_layers=28,
+ position_encoding_2d=True,
+ empty_init=True
+ ):
+ super(GLMBlock, self).__init__()
+ # Set output layer initialization if not provided.
+
+ self.layer_id = layer_id
+
+ # Layernorm on the input data.
+ self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+ self.position_encoding_2d = position_encoding_2d
+
+ # Self attention.
+ self.attention = SelfAttention(
+ hidden_size,
+ num_attention_heads,
+ layer_id,
+ hidden_size_per_attention_head=hidden_size_per_attention_head,
+ bias=use_bias,
+ params_dtype=params_dtype,
+ position_encoding_2d=self.position_encoding_2d,
+ empty_init=empty_init
+ )
+
+ # Layernorm on the input data.
+ self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+ self.num_layers = num_layers
+
+ # GLU
+ self.mlp = GLU(
+ hidden_size,
+ inner_hidden_size=inner_hidden_size,
+ bias=use_bias,
+ layer_id=layer_id,
+ params_dtype=params_dtype,
+ empty_init=empty_init
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ attention_mask: [(1, 1), seq_len, seq_len]
+ """
+
+ # Layer norm at the beginning of the transformer layer.
+ # [seq_len, batch, hidden_size]
+ attention_input = self.input_layernorm(hidden_states)
+
+ # Self attention.
+ attention_outputs = self.attention(
+ attention_input,
+ position_ids,
+ attention_mask=attention_mask,
+ layer_id=layer_id,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions
+ )
+
+ attention_output = attention_outputs[0]
+
+ outputs = attention_outputs[1:]
+
+ # Residual connection.
+ alpha = (2 * self.num_layers) ** 0.5
+ hidden_states = attention_input * alpha + attention_output
+
+ mlp_input = self.post_attention_layernorm(hidden_states)
+
+ # MLP.
+ mlp_output = self.mlp(mlp_input)
+
+ # Second residual connection.
+ output = mlp_input * alpha + mlp_output
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+
+class ChatGLMPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+
+ is_parallelizable = False
+ supports_gradient_checkpointing = True
+ config_class = ChatGLMConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["GLMBlock"]
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ return
+
+ def get_masks(self, input_ids, device):
+ batch_size, seq_length = input_ids.shape
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+ attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
+ attention_mask.tril_()
+ for i, context_length in enumerate(context_lengths):
+ attention_mask[i, :, :context_length] = 1
+ attention_mask.unsqueeze_(1)
+ attention_mask = (attention_mask < 0.5).bool()
+
+ return attention_mask
+
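+ # Worked example (illustrative, not part of the original patch): for
+ # input_ids = [p1, p2, gMASK, BOS, t1], context_length = 3, so the prompt (including
+ # gMASK) stays visible to every position while positions from BOS onwards are causal.
+ # True marks the positions that are masked out:
+ #
+ #            p1     p2    gMASK   BOS    t1
+ #   p1     False  False  False   True   True
+ #   p2     False  False  False   True   True
+ #   gMASK  False  False  False   True   True
+ #   BOS    False  False  False   False  True
+ #   t1     False  False  False   False  False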
+ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
+ batch_size, seq_length = input_ids.shape
+ if use_gmasks is None:
+ use_gmasks = [False] * batch_size
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+ if self.position_encoding_2d:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ for i, context_length in enumerate(context_lengths):
+ position_ids[i, context_length:] = mask_positions[i]
+ block_position_ids = [torch.cat((
+ torch.zeros(context_length, dtype=torch.long, device=device),
+ torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+ )) for context_length in context_lengths]
+ block_position_ids = torch.stack(block_position_ids, dim=0)
+ position_ids = torch.stack((position_ids, block_position_ids), dim=1)
+ else:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ for i, context_length in enumerate(context_lengths):
+ if not use_gmasks[i]:
+ position_ids[i, context_length:] = mask_positions[i]
+
+ return position_ids
+
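+ # Worked example (illustrative, not part of the original patch): with 2D position
+ # encoding and input_ids = [p1, p2, gMASK, BOS, t1] (context_length = 3, gMASK at
+ # index 2), one batch row gives
+ #   position_ids[0]       = [0, 1, 2, 2, 2]   # positions from BOS onwards are set to the gMASK index
+ #   block_position_ids[0] = [0, 0, 0, 1, 2]   # 0 inside the prompt, then 1, 2, ...
+ # and the two rows are stacked into shape (batch_size, 2, seq_length).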
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, ChatGLMModel):
+ module.gradient_checkpointing = value
+
+
+CHATGLM_6B_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
+ usage and behavior.
+
+ Parameters:
+ config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
+ Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CHATGLM_6B_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`ChatGLM6BTokenizer`].
+ See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Selected in the range `[0, config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert *input_ids* indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
+ CHATGLM_6B_START_DOCSTRING,
+)
+class ChatGLMModel(ChatGLMPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well
+ as a decoder, in which case a layer of cross-attention is added between
+ the self-attention layers, following the architecture described in [Attention is
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+ To behave as a decoder the model needs to be initialized with the
+ `is_decoder` argument of the configuration set to `True`.
+ To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder`
+ argument and `add_cross_attention` set to `True`; an
+ `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ # recording parameters
+ self.max_sequence_length = config.max_sequence_length
+ self.hidden_size = config.hidden_size
+ self.params_dtype = torch.half
+ self.num_attention_heads = config.num_attention_heads
+ self.vocab_size = config.vocab_size
+ self.num_layers = config.num_layers
+ self.layernorm_epsilon = config.layernorm_epsilon
+ self.inner_hidden_size = config.inner_hidden_size
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
+ self.position_encoding_2d = config.position_encoding_2d
+ self.pre_seq_len = config.pre_seq_len
+ self.prefix_projection = config.prefix_projection
+
+ self.word_embeddings = init_method(
+ torch.nn.Embedding,
+ num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
+ dtype=self.params_dtype
+ )
+ self.gradient_checkpointing = False
+
+ def get_layer(layer_id):
+ return GLMBlock(
+ self.hidden_size,
+ self.num_attention_heads,
+ self.layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=self.inner_hidden_size,
+ hidden_size_per_attention_head=self.hidden_size_per_attention_head,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=self.params_dtype,
+ position_encoding_2d=self.position_encoding_2d,
+ empty_init=empty_init
+ )
+
+ self.layers = torch.nn.ModuleList(
+ [get_layer(layer_id) for layer_id in range(self.num_layers)]
+ )
+
+ # Final layer norm before output.
+ self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
+
+ if self.pre_seq_len is not None:
+ for param in self.parameters():
+ param.requires_grad = False
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+ self.dropout = torch.nn.Dropout(0.1)
+
+ # total_params = sum(p.numel() for p in self.parameters())
+ # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.word_embeddings = new_embeddings
+
+ def get_prompt(self, batch_size, device, dtype=torch.half):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.num_layers * 2,
+ self.num_attention_heads,
+ self.hidden_size // self.num_attention_heads
+ )
+ # seq_len, b, nh, hidden_size
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+ # past_key_values = [(v[0], v[1]) for v in past_key_values]
+ return past_key_values
+
+ @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if past_key_values is None:
+ if self.pre_seq_len is not None:
+ past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
+ dtype=inputs_embeds.dtype)
+ else:
+ past_key_values = tuple([None] * len(self.layers))
+
+ if attention_mask is None:
+ attention_mask = self.get_masks(
+ input_ids,
+ device=input_ids.device
+ )
+
+ if position_ids is None:
+ MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
+ seqs = input_ids.tolist()
+
+ mask_positions, use_gmasks = [], []
+ for seq in seqs:
+ mask_token = gMASK if gMASK in seq else MASK
+ use_gmask = mask_token == gMASK
+ mask_positions.append(seq.index(mask_token))
+ use_gmasks.append(use_gmask)
+
+ position_ids = self.get_position_ids(
+ input_ids,
+ mask_positions=mask_positions,
+ device=input_ids.device,
+ use_gmasks=use_gmasks
+ )
+
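+        # The learned prefix should stay visible to every token, so prepend an all-False (unmasked) block to the boolean mask.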
+ if self.pre_seq_len is not None and attention_mask is not None:
+ prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+ attention_mask.device)
+ prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
+
+ # [seq_len, batch, hidden_size]
+ hidden_states = inputs_embeds.transpose(0, 1)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if attention_mask is None:
+ attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
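+        # Run the input through the stack of GLM blocks, using activation checkpointing when enabled during training.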
+ for i, layer in enumerate(self.layers):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ layer_past = past_key_values[i]
+
+ if self.gradient_checkpointing and self.training:
+ layer_ret = torch.utils.checkpoint.checkpoint(
+ layer,
+ hidden_states,
+ position_ids,
+ attention_mask,
+ torch.tensor(i),
+ layer_past,
+ use_cache,
+ output_attentions
+ )
+ else:
+ layer_ret = layer(
+ hidden_states,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ layer_id=torch.tensor(i),
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions
+ )
+
+ hidden_states = layer_ret[0]
+
+ if use_cache:
+ presents = presents + (layer_ret[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
+
+ # Final layer norm.
+ hidden_states = self.final_layernorm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+
+ # self.hidden_size = config.hidden_size
+ # self.params_dtype = torch.half
+ # self.vocab_size = config.vocab_size
+ self.max_sequence_length = config.max_sequence_length
+
+ self.position_encoding_2d = config.position_encoding_2d
+
+ self.transformer = ChatGLMModel(config, empty_init=empty_init)
+
+ self.lm_head = init_method(
+ nn.Linear,
+ config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=torch.half
+ )
+
+ self.config = config
+
+ self.quantized = False
+
+ if self.config.quantization_bit:
+ self.quantize(self.config.quantization_bit, empty_init=True)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ # update past_key_values
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+
+ # update attention mask
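+        # The mask is boolean with shape (batch, 1, query_len, key_len); True marks positions that must not be attended to.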
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+ new_attention_mask = attention_mask[:, :, -1:].clone()
+ new_attention_mask[..., -1] = False
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, new_attention_mask], dim=2
+ )
+
+ # update position ids
+ if "position_ids" in model_kwargs:
+ position_ids = model_kwargs["position_ids"]
+ new_position_id = position_ids[..., -1:].clone()
+ new_position_id[:, 1, :] += 1
+ model_kwargs["position_ids"] = torch.cat(
+ [position_ids, new_position_id], dim=-1
+ )
+
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past: Optional[torch.Tensor] = None,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs
+ ) -> dict:
+ batch_size, seq_length = input_ids.shape
+ MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
+ seqs = input_ids.tolist()
+ mask_positions, use_gmasks = [], []
+ for seq in seqs:
+ mask_token = gMASK if gMASK in seq else MASK
+ use_gmask = mask_token == gMASK
+ mask_positions.append(seq.index(mask_token))
+ use_gmasks.append(use_gmask)
+
+ # only last token for input_ids if past is not None
+ if past is not None or past_key_values is not None:
+ last_token = input_ids[:, -1].unsqueeze(-1)
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
+ attention_mask = attention_mask[:, :, -1:]
+ else:
+ attention_mask = None
+ if position_ids is not None:
+ position_ids = position_ids[..., -1:]
+ else:
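+                # No position ids were passed along with the cache: rebuild the position id of the current token
+                # (2D case: [gMASK position, block position after the context]).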
+ context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
+ if self.position_encoding_2d:
+ position_ids = torch.tensor(
+ [[mask_position, seq_length - context_length] for mask_position, context_length in
+ zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+ else:
+ position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
+ device=input_ids.device).unsqueeze(-1)
+
+ if past is None:
+ past = past_key_values
+ return {
+ "input_ids": last_token,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask
+ }
+ else:
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
+ logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
+ attention_mask = None
+ if attention_mask is None:
+ attention_mask = self.get_masks(
+ input_ids,
+ device=input_ids.device
+ )
+ if position_ids is None:
+ position_ids = self.get_position_ids(
+ input_ids,
+ device=input_ids.device,
+ mask_positions=mask_positions,
+ use_gmasks=use_gmasks
+ )
+
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask
+ }
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+
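+        # The transformer returns hidden states as [seq_len, batch, hidden]; permute the logits back to [batch, seq_len, vocab].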
+ lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()
+
+ loss = None
+ if labels is not None:
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+
+ Output shares the same memory storage as `past`.
+ """
+ return tuple(
+ (
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
+ )
+ for layer_past in past
+ )
+
+ def process_response(self, response):
+ response = response.strip()
+ response = response.replace("[[训练时间]]", "2023年")
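+        # Convert half-width punctuation that touches CJK characters into its full-width equivalent.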
+ punkts = [
+ [",", ","],
+ ["!", "!"],
+ [":", ":"],
+ [";", ";"],
+ ["\?", "?"],
+ ]
+ for item in punkts:
+ response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+ response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+ return response
+
+ @torch.no_grad()
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
+ do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
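+        # Assemble the multi-turn prompt in ChatGLM's "[Round i]\n问:...\n答:..." format.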
+ if not history:
+ prompt = query
+ else:
+ prompt = ""
+ for i, (old_query, response) in enumerate(history):
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ outputs = self.generate(**inputs, **gen_kwargs)
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ response = tokenizer.decode(outputs)
+ response = self.process_response(response)
+ history = history + [(query, response)]
+ return response, history
+
+ @torch.no_grad()
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+ do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ if not history:
+ prompt = query
+ else:
+ prompt = ""
+ for i, (old_query, response) in enumerate(history):
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ for outputs in self.stream_generate(**inputs, **gen_kwargs):
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ response = tokenizer.decode(outputs)
+ response = self.process_response(response)
+ new_history = history + [(query, response)]
+ yield response, new_history
+
+ @torch.no_grad()
+ def stream_generate(
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ **kwargs,
+ ):
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ if not has_default_max_length:
+                logger.warning(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ logits_warper = self._get_logits_warper(generation_config)
+
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ scores = None
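+        # Token-by-token decoding loop; unlike `generate`, it yields the full sequence generated so far after every step.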
+ while True:
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ if generation_config.do_sample:
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(probs, dim=-1)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+ unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ break
+ yield input_ids
+
+ def quantize(self, bits: int, empty_init=False, **kwargs):
+ if bits == 0:
+ return
+
+ from .quantization import quantize
+
+ if self.quantized:
+ logger.info("Already quantized.")
+ return self
+
+ self.quantized = True
+
+ self.config.quantization_bit = bits
+
+ self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
+ return self
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index 0812ba165286..e4d0a970740d 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -52,9 +52,13 @@ def _train(self, epoch: int):
for batch_id, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"],
- attention_mask=batch["attention_mask"],
- labels=batch["labels"])
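+            # The ChatGLM SFT dataset provides no attention mask (the model derives it from the input ids), so only pass it when present.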
+ if "attention_mask" in batch:
+ outputs = self.model(batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"])
+ else:
+ outputs = self.model(batch["input_ids"],
+ labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt
index 40e6edc7ea73..5d0f9f927d17 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/Chat/examples/requirements.txt
@@ -1,2 +1,3 @@
pandas>=1.4.1
sentencepiece
+colossalai==0.3.1
\ No newline at end of file
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index 7585cf3ed0da..f068ea2bf5de 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -9,13 +9,15 @@
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
+from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler
@@ -58,6 +60,8 @@ def train(args):
model = LlamaActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
+ elif args.model == 'chatglm':
+ model = ChatGLMActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -81,6 +85,9 @@ def train(args):
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '</s>'
tokenizer.pad_token = tokenizer.unk_token
+ elif args.model == 'chatglm':
+ tokenizer = ChatGLMTokenizer.from_pretrained(
+ "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -99,7 +106,6 @@ def train(args):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
-
logger = get_dist_logger()
# configure dataset
@@ -185,7 +191,7 @@ def train(args):
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
+ parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None)
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
index e079f8a6038d..eb1a77875acb 100644
--- a/applications/Chat/requirements-test.txt
+++ b/applications/Chat/requirements-test.txt
@@ -1 +1,2 @@
pytest
+colossalai==0.3.1
\ No newline at end of file
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
index af7ff67861eb..e5f5ca0932a8 100644
--- a/applications/Chat/requirements.txt
+++ b/applications/Chat/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
-colossalai>=0.2.4
+colossalai==0.3.1
torch<2.0.0, >=1.12.1
langchain
tokenizers
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
index 64ea1178cd0d..ea3c7b5851e2 100644
--- a/applications/Chat/tests/test_dataset.py
+++ b/applications/Chat/tests/test_dataset.py
@@ -11,7 +11,7 @@
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
SFT_DATASET = [
{
"instruction": "Provide a list of the top 10 most popular mobile games in Asia",
@@ -66,6 +66,8 @@ def make_tokenizer(model: str):
elif model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.unk_token
+ elif model == "chatglm":
+ tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
else:
raise ValueError(f"Unsupported model '{model}'")
return tokenizer
@@ -81,13 +83,19 @@ def check_content(input_ids_stripped: torch.Tensor,
elif model == "llama":
assert input_ids_stripped[0] == tokenizer.bos_token_id
input_ids_stripped = input_ids_stripped[1:]
-
+ elif model == "chatglm":
+ assert input_ids_stripped[0] == tokenizer.bos_token_id
+ assert input_ids_stripped[-1] == tokenizer.eos_token_id
+ input_ids_stripped = input_ids_stripped[1:-1]
assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
assert input_ids_stripped != tokenizer.sep_token_id
assert input_ids_stripped != tokenizer.cls_token_id
- assert input_ids_stripped != tokenizer.mask_token_id
+ if model == "chatglm":
+ assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
+ else:
+ assert input_ids_stripped != tokenizer.mask_token_id
@pytest.mark.cpu
@@ -189,7 +197,7 @@ def test_reward_dataset(model: str,
@pytest.mark.cpu
-@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024])
@@ -213,6 +221,19 @@ def test_sft_dataset(model: str,
max_length=max_length)
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ for i in range(max_dataset_size):
+ assert isinstance(sft_dataset[i], dict)
+ assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
+ input_ids = sft_dataset[i]["input_ids"]
+ labels = sft_dataset[i]["labels"]
+ assert input_ids.shape == labels.shape == torch.Size([max_length])
+
+ ignore_mask = labels == IGNORE_INDEX
+ assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
+ check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
+ return
+
for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
@@ -245,4 +266,4 @@ def test_sft_dataset(model: str,
test_prompt_dataset(model="opt",
max_datasets_size=2,
- max_length=128)
+ max_length=128)
\ No newline at end of file
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
index bd6b3e8a5ad1..7b13becc3656 100644
--- a/applications/Chat/tests/test_models.py
+++ b/applications/Chat/tests/test_models.py
@@ -9,11 +9,12 @@
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.chatglm import ChatGLMActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
-
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [4])
@@ -23,7 +24,8 @@
lambda: GPTActor(),
# HACK: skip llama due to long execution time
# lambda: LlamaActor(),
- lambda: OPTActor()
+ lambda: OPTActor(),
+ # lambda: ChatGLMActor(),
])
@pytest.mark.parametrize("generate_kwargs", [{
"max_length": 64,
@@ -129,12 +131,12 @@ def test_lora(lora_rank: int,
# HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda: (OPTActor(), OPTCritic(), OPTRM()),
+ lambda: (ChatGLMActor(), None, None),
])
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
batch_size: int,
seq_len: int):
-
actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
@@ -150,20 +152,30 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
}
actor, critic, rm = models_maker()
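+    # ChatGLM expects the gMASK/BOS marker tokens inside the sequence (its forward locates them to build position ids)
+    # and a (batch, 1, seq_len, seq_len) attention mask, so build a dedicated input for it.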
+ if isinstance(actor, ChatGLMActor):
+ actor = actor.float()
+        tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
+        actor_input = {
+ "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1),
+ "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
+ }
assert isinstance(actor, Actor)
base_actor_model = get_base_model(actor)
- assert isinstance(critic, Critic)
- base_critic_model = get_base_model(critic)
- assert isinstance(rm, RewardModel)
- base_rm_model = get_base_model(rm)
-
actor_output = actor(**actor_input)
- critic_output = critic(**critic_input)
- rm_output = rm(**rm_input)
-
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
- assert critic_output.shape == (batch_size, )
- assert rm_output.shape == (batch_size, )
+
+ if critic:
+ assert isinstance(critic, Critic)
+ base_critic_model = get_base_model(critic)
+ critic_output = critic(**critic_input)
+ assert critic_output.shape == (batch_size, )
+
+ if rm:
+ assert isinstance(rm, RewardModel)
+ base_rm_model = get_base_model(rm)
+ rm_output = rm(**rm_input)
+ assert rm_output.shape == (batch_size, )
@pytest.mark.cpu
@@ -232,4 +244,4 @@ def test_loss(batch_size: int,
batch_size=8,
seq_len=128)
- test_loss(batch_size=8, seq_len=128, num_labels=100)
+ test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 616b218b2070..6efafc56d5d5 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -17,8 +17,13 @@
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
get_shard_filename,
+ load_param_groups_into_optimizer,
+ load_shard_state_dict,
+ load_states_into_optimizer,
save_param_groups,
save_state_dict,
+ sharded_optimizer_loading_epilogue,
+ unwrap_optimizer,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -126,19 +131,39 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
- super().load_sharded_optimizer(optimizer, index_file_path, prefix)
- current_rank_state_dict = optimizer.optim.state_dict()['state']
- for param_idx, state in current_rank_state_dict.items():
- for k, v in state.items():
- if isinstance(v, torch.Tensor) and k != 'step':
- padding_size = (self.coordinator.world_size -
- v.numel() % self.coordinator.world_size) % self.coordinator.world_size
- with torch.no_grad():
- v = v.flatten()
- if padding_size > 0:
- v = torch.nn.functional.pad(v, [0, padding_size])
- v_list = v.split(v.numel() // self.coordinator.world_size)
- current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = unwrap_optimizer(optimizer)
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. '
+                               'Lacking param group file under current directory.')
+ id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ for shard_file in checkpoint_files:
+ state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            # Shard the flattened optimizer states: pad each tensor to a multiple of world_size and keep only this rank's slice.
+ for param_idx, state in state_dict.items():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+ padding_size = (self.coordinator.world_size -
+ v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ v_list = v.split(v.numel() // self.coordinator.world_size)
+ state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+ load_states_into_optimizer(optimizer, state_dict, id_map)
+
+ sharded_optimizer_loading_epilogue(optimizer)
class LowLevelZeroModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 83e4bdcc863b..34210ea52162 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -78,8 +78,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
- del state_dict
- gc.collect()
sharded_optimizer_loading_epilogue(optimizer)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 8837776aee4d..77ff7784a514 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -237,7 +237,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
return safe_load_file(checkpoint_file)
else:
- return torch.load(checkpoint_file)
+ return torch.load(checkpoint_file, map_location=torch.device('cpu'))
def load_state_dict_into_model(model: nn.Module,
@@ -297,7 +297,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
- saved_groups = torch.load(param_group_path)
+ saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
@@ -608,7 +608,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
- return torch.load(checkpoint_file_path)
+ return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 96d5902e893f..b4439ab19adf 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -307,7 +307,7 @@ def _add_to_bucket(self, param, group_id):
# or got a grad of param from another group
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
- group_id != self._bucket_store.current_group_id:
+ group_id != self._bucket_store.current_group_id:
self._run_reduction()
padding_size = self._param_store.get_param_padding_size(param)
@@ -553,11 +553,9 @@ def load_state_dict(self, state_dict: Dict):
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
- device = 'cpu' if self._cpu_offload else 'cuda'
- zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
+ zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
- zero_state_dict = dict()
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 945ca4080413..dda4f86a29a0 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## 新闻
+* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
@@ -49,7 +50,7 @@
-
并行训练样例展示
- - LLaMA
+ - LLaMA 1/2
- GPT-3
- GPT-2
- BERT
@@ -210,7 +211,16 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
(返回顶端)
## 并行训练样例展示
-### LLaMA
+### LLaMA2
+
+
+
+
+- 700亿参数LLaMA2训练加速195%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
+
+### LLaMA1
diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md
index b64b5d29ecb8..483eae88ae32 100644
--- a/examples/language/llama2/README.md
+++ b/examples/language/llama2/README.md
@@ -1,4 +1,22 @@
-# Pretraining LLaMA-2: best practices for building LLaMA-2-like base models
+# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models
+
+### LLaMA2
+
+
+
+
+- 70 billion parameter LLaMA2 model training accelerated by 195%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
+
+### LLaMA1
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
## Dataset
@@ -73,7 +91,7 @@ Make sure master node can access all nodes (including itself) by ssh without pas
Here are the details about the CLI arguments:
-- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported.
+- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1; `7b`, `13b`, and `70b` are supported for LLaMA-2.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It supports any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
@@ -105,7 +123,7 @@ Here we will show an example of how to run training
llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.
#### a. Running environment
-This experiment was performed on 4 computing nodes with 32 A800 GPUs in total. The nodes are
+For LLaMA-1 65B, this experiment was performed on 4 computing nodes with 32 A800 GPUs in total. The nodes are
connected with RDMA and GPUs within one node are fully connected with NVLink.
#### b. Running command