From 9c1697f3d0da3dd5081e3fa9c8202a8b0f59c537 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 4 Jul 2023 14:35:55 +0800 Subject: [PATCH 01/15] [shardformer] added tests --- tests/test_shardformer/test_model/test_shard_vit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index a96fd02ae746..842fed23a263 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -56,6 +56,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_tensor_parallelism', [True, False]) def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + print(sub_model_zoo) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) From b07e0c990dbb19744f5f49d8979561ea776948ea Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Thu, 6 Jul 2023 10:59:42 +0800 Subject: [PATCH 02/15] [shardformer] vit test finish and support --- tests/test_shardformer/test_model/test_shard_vit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 842fed23a263..a96fd02ae746 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -56,7 +56,6 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_tensor_parallelism', [True, False]) def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - print(sub_model_zoo) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) From 3a28444781466dba07505348bcce6b950c742111 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Thu, 6 Jul 2023 17:54:58 +0800 Subject: [PATCH 03/15] [shardformer] chatglm ready --- colossalai/shardformer/policies/chatglm.py | 10 ++++++++++ .../test_shardformer/test_model/test_shard_chatglm.py | 0 2 files changed, 10 insertions(+) create mode 100644 colossalai/shardformer/policies/chatglm.py create mode 100644 tests/test_shardformer/test_model/test_shard_chatglm.py diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py new file mode 100644 index 000000000000..5a41b3ab361d --- /dev/null +++ b/colossalai/shardformer/policies/chatglm.py @@ -0,0 +1,10 @@ +from typing import Dict, Union + +import torch.nn as nn + +from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + + diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py new file mode 100644 index 000000000000..e69de29bb2d1 From 7f9c3a13b12094b905ba6a5f8b151d79f4a3e86e Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Fri, 7 Jul 2023 19:16:35 +0800 Subject: [PATCH 04/15] import chatglm --- =2.0 | 134 ++ tests/kit/model_zoo/transformers/chatglm.py | 0 .../chatglm2-6b/modeling_chatglm.py | 1193 +++++++++++++++++ 3 files changed, 1327 insertions(+) create mode 100644 =2.0 create mode 100644 tests/kit/model_zoo/transformers/chatglm.py create mode 100644 tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py diff --git a/=2.0 b/=2.0 new file mode 100644 index 000000000000..af47ce17aa8e --- /dev/null +++ b/=2.0 @@ -0,0 +1,134 @@ +Defaulting to user installation because normal site-packages is not writeable +Collecting protobuf + Using cached protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl (304 kB) +Requirement already satisfied: transformers==4.30.2 in /home/lclk/.local/lib/python3.9/site-packages (4.30.2) +Collecting cpm_kernels + Using cached cpm_kernels-1.0.11-py3-none-any.whl (416 kB) +Requirement already satisfied: torch in /home/lclk/.local/lib/python3.9/site-packages (2.0.0+cu118) +Collecting gradio + Using cached gradio-3.36.0-py3-none-any.whl (19.8 MB) +Collecting mdtex2html + Using cached mdtex2html-1.2.0-py3-none-any.whl (13 kB) +Collecting sentencepiece + Using cached sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB) +Collecting accelerate + Using cached accelerate-0.20.3-py3-none-any.whl (227 kB) +Requirement already satisfied: pyyaml>=5.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (6.0) +Requirement already satisfied: regex!=2019.12.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (2023.6.3) +Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.15.1) +Requirement already satisfied: packaging>=20.0 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (23.1) +Requirement already satisfied: requests in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from transformers==4.30.2) (2.25.1) +Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.13.3) +Requirement already satisfied: safetensors>=0.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.3.1) +Requirement already satisfied: filelock in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (3.12.0) +Requirement already satisfied: numpy>=1.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (1.24.3) +Requirement already satisfied: tqdm>=4.27 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (4.65.0) +Requirement already satisfied: fsspec in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (2023.6.0) +Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (4.6.3) +Requirement already satisfied: networkx in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1) +Requirement already satisfied: sympy in /home/lclk/.local/lib/python3.9/site-packages (from torch) (1.12) +Requirement already satisfied: triton==2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (2.0.0) +Requirement already satisfied: jinja2 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1.2) +Requirement already satisfied: lit in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (16.0.5.post0) +Requirement already satisfied: cmake in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (3.26.3) +Collecting aiofiles + Using cached aiofiles-23.1.0-py3-none-any.whl (14 kB) +Collecting ffmpy + Using cached ffmpy-0.3.0.tar.gz (4.8 kB) +Requirement already satisfied: pillow in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (9.5.0) +Collecting pydub + Using cached pydub-0.25.1-py2.py3-none-any.whl (32 kB) +Requirement already satisfied: pandas in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.0.2) +Collecting python-multipart + Using cached python_multipart-0.0.6-py3-none-any.whl (45 kB) +Collecting semantic-version + Using cached semantic_version-2.10.0-py2.py3-none-any.whl (15 kB) +Collecting pydantic + Using cached pydantic-2.0.2-py3-none-any.whl (359 kB) +Collecting uvicorn>=0.14.0 + Using cached uvicorn-0.22.0-py3-none-any.whl (58 kB) +Collecting mdit-py-plugins<=0.3.3 + Using cached mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB) +Requirement already satisfied: pygments>=2.12.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.15.1) +Collecting httpx + Using cached httpx-0.24.1-py3-none-any.whl (75 kB) +Collecting orjson + Using cached orjson-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (136 kB) +Collecting fastapi + Using cached fastapi-0.99.1-py3-none-any.whl (58 kB) +Collecting altair>=4.2.0 + Using cached altair-5.0.1-py3-none-any.whl (471 kB) +Collecting gradio-client>=0.2.7 + Using cached gradio_client-0.2.7-py3-none-any.whl (288 kB) +Requirement already satisfied: aiohttp in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.8.4) +Requirement already satisfied: matplotlib in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.7.1) +Collecting websockets>=10.0 + Using cached websockets-11.0.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB) +Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.2.0) +Requirement already satisfied: markupsafe in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.1.3) +Collecting toolz + Using cached toolz-0.12.0-py3-none-any.whl (55 kB) +Collecting jsonschema>=3.0 + Using cached jsonschema-4.18.0-py3-none-any.whl (81 kB) +Collecting rpds-py>=0.7.1 + Downloading rpds_py-0.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB) +Collecting referencing>=0.28.4 + Using cached referencing-0.29.1-py3-none-any.whl (25 kB) +Collecting jsonschema-specifications>=2023.03.6 + Using cached jsonschema_specifications-2023.6.1-py3-none-any.whl (17 kB) +Requirement already satisfied: attrs>=22.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from jsonschema>=3.0->altair>=4.2.0->gradio) (23.1.0) +Requirement already satisfied: mdurl~=0.1 in /home/lclk/.local/lib/python3.9/site-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2) +Collecting linkify-it-py<3,>=1 + Downloading linkify_it_py-2.0.2-py3-none-any.whl (19 kB) +Collecting uc-micro-py + Downloading uc_micro_py-1.0.2-py3-none-any.whl (6.2 kB) +Requirement already satisfied: pytz>=2020.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) +Requirement already satisfied: tzdata>=2022.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) +Requirement already satisfied: python-dateutil>=2.8.2 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2.8.2) +Requirement already satisfied: six>=1.5 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas->gradio) (1.16.0) +Requirement already satisfied: click>=7.0 in /home/lclk/.local/lib/python3.9/site-packages (from uvicorn>=0.14.0->gradio) (8.1.3) +Collecting h11>=0.8 + Downloading h11-0.14.0-py3-none-any.whl (58 kB) +Collecting latex2mathml + Downloading latex2mathml-3.76.0-py3-none-any.whl (73 kB) +Collecting markdown + Downloading Markdown-3.4.3-py3-none-any.whl (93 kB) +Requirement already satisfied: psutil in /home/lclk/.local/lib/python3.9/site-packages (from accelerate) (5.9.5) +Requirement already satisfied: multidict<7.0,>=4.5 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (6.0.4) +Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (4.0.2) +Requirement already satisfied: aiosignal>=1.1.2 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.1) +Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (3.1.0) +Requirement already satisfied: frozenlist>=1.1.1 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.3) +Requirement already satisfied: yarl<2.0,>=1.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.9.2) +Requirement already satisfied: idna>=2.0 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from yarl<2.0,>=1.0->aiohttp->gradio) (2.10) +Collecting pydantic + Downloading pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB) +Collecting starlette<0.28.0,>=0.27.0 + Downloading starlette-0.27.0-py3-none-any.whl (66 kB) +Collecting anyio<5,>=3.4.0 + Downloading anyio-3.7.1-py3-none-any.whl (80 kB) +Collecting sniffio>=1.1 + Downloading sniffio-1.3.0-py3-none-any.whl (10 kB) +Requirement already satisfied: exceptiongroup in /home/lclk/.local/lib/python3.9/site-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->gradio) (1.1.1) +Collecting httpcore<0.18.0,>=0.15.0 + Downloading httpcore-0.17.3-py3-none-any.whl (74 kB) +Requirement already satisfied: certifi in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from httpx->gradio) (2021.5.30) +Requirement already satisfied: importlib-metadata>=4.4 in /home/lclk/.local/lib/python3.9/site-packages (from markdown->mdtex2html) (6.7.0) +Requirement already satisfied: zipp>=0.5 in /home/lclk/.local/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown->mdtex2html) (3.15.0) +Requirement already satisfied: contourpy>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.1.0) +Requirement already satisfied: fonttools>=4.22.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (4.40.0) +Requirement already satisfied: pyparsing>=2.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (3.1.0) +Requirement already satisfied: kiwisolver>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.4.4) +Requirement already satisfied: importlib-resources>=3.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (5.12.0) +Requirement already satisfied: cycler>=0.10 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (0.11.0) +Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (1.26.6) +Requirement already satisfied: chardet<5,>=3.0.2 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (4.0.0) +Requirement already satisfied: mpmath>=0.19 in /home/lclk/.local/lib/python3.9/site-packages (from sympy->torch) (1.3.0) +Building wheels for collected packages: ffmpy + Building wheel for ffmpy (setup.py): started + Building wheel for ffmpy (setup.py): finished with status 'done' + Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4709 sha256=071cebb58ca6c6947fbc669e1d94509d6f53d1ed45d9d7fb9f060d1a342cfc18 + Stored in directory: /home/lclk/.cache/pip/wheels/91/e2/96/f676aa08bfd789328c6576cd0f1fde4a3d686703bb0c247697 +Successfully built ffmpy +Installing collected packages: sniffio, rpds-py, referencing, h11, anyio, uc-micro-py, jsonschema-specifications, httpcore, websockets, toolz, starlette, pydantic, linkify-it-py, jsonschema, httpx, uvicorn, semantic-version, python-multipart, pydub, orjson, mdit-py-plugins, markdown, latex2mathml, gradio-client, ffmpy, fastapi, altair, aiofiles, sentencepiece, protobuf, mdtex2html, gradio, cpm-kernels, accelerate +Successfully installed accelerate-0.20.3 aiofiles-23.1.0 altair-5.0.1 anyio-3.7.1 cpm-kernels-1.0.11 fastapi-0.99.1 ffmpy-0.3.0 gradio-3.36.0 gradio-client-0.2.7 h11-0.14.0 httpcore-0.17.3 httpx-0.24.1 jsonschema-4.18.0 jsonschema-specifications-2023.6.1 latex2mathml-3.76.0 linkify-it-py-2.0.2 markdown-3.4.3 mdit-py-plugins-0.3.3 mdtex2html-1.2.0 orjson-3.9.1 protobuf-4.23.4 pydantic-1.10.11 pydub-0.25.1 python-multipart-0.0.6 referencing-0.29.1 rpds-py-0.8.8 semantic-version-2.10.0 sentencepiece-0.1.99 sniffio-1.3.0 starlette-0.27.0 toolz-0.12.0 uc-micro-py-1.0.2 uvicorn-0.22.0 websockets-11.0.3 diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py new file mode 100644 index 000000000000..82163c46190f --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py @@ -0,0 +1,1193 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm2-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, **_config_to_kwargs(config) + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, + device=device, **_config_to_kwargs(config) + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, + device=input_ids.device), full_attention_mask), dim=-1) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, + dtype=config.torch_dtype, **init_kwargs) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + return response + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + if history: + prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) + else: + prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1, + do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None, + max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, + return_past_key_values=False, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs(tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs['attention_mask'] = attention_mask + for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, + return_past_key_values=return_past_key_values, **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, + **kwargs) + return self \ No newline at end of file From 4137e03acbce5165e6cd7b3ebb5966d9334b041a Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Fri, 7 Jul 2023 19:56:22 +0800 Subject: [PATCH 05/15] [shardformer] add test kit in model zoo for chatglm --- tests/kit/model_zoo/transformers/chatglm.py | 35 ++++ .../transformers/chatglm2_6b/MODEL_LICENSE | 33 +++ .../chatglm2_6b/configuration_chatglm.py | 59 ++++++ .../modeling_chatglm.py | 0 .../transformers/chatglm2_6b/quantization.py | 188 ++++++++++++++++++ 5 files changed, 315 insertions(+) create mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE create mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py rename tests/kit/model_zoo/transformers/{chatglm2-6b => chatglm2_6b}/modeling_chatglm.py (100%) create mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index e69de29bb2d1..bed5e4bfb429 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -0,0 +1,35 @@ +import torch +import transformers +from chatglm2_6b.configuration_chatglm import ChatGLMConfig + +from ..registry import ModelAttribute, model_zoo + +# ================================ +# Register single-sentence ChatGLM +# ================================ + +config = ChatGLMConfig(num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4) + +def data_gen(): + + input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean() +loss_fn = lambda x: x.loss + +model_zoo.register(name='transformers_chatglm', + model_fn=lambda: transformers.ChatGLMModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_chatglm_model, + model_attribute=ModelAttribute(has_control_flow=True)) + diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE b/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE new file mode 100644 index 000000000000..26198b21b6b2 --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE @@ -0,0 +1,33 @@ +The ChatGLM2-6B License + +1. Definitions + +“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. + +“Software” means the ChatGLM2-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py new file mode 100644 index 000000000000..3730b4e34d86 --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py @@ -0,0 +1,59 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py similarity index 100% rename from tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py rename to tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py b/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py new file mode 100644 index 000000000000..cb95bfe82b20 --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py @@ -0,0 +1,188 @@ +from torch.nn import Linear +from torch.nn.parameter import Parameter + +import bz2 +import torch +import base64 +import ctypes +from transformers.utils import logging + +from typing import List +from functools import partial + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + "int4WeightCompression", + "int4WeightExtractionFloat", + "int4WeightExtractionHalf", + "int8WeightExtractionFloat", + "int8WeightExtractionHalf", + ], + ) +except Exception as exception: + kernels = None + logger.warning("Failed to load cpm_kernels:" + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features,))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): + assert scale_list.dtype in [torch.half, torch.bfloat16] + assert weight.dtype in [torch.int8] + if source_bit_width == 8: + return weight.to(scale_list.dtype) * scale_list[:, None] + elif source_bit_width == 4: + func = ( + kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16 + ) + else: + assert False, "Unsupported bit-width" + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(torch.nn.Module): + def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args, + **kwargs): + super().__init__() + self.weight_bit_width = weight_bit_width + + shape = weight.shape + + if weight is None or empty_init: + self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device) + self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device) + else: + self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) + self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(device), requires_grad=False) + self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False) + self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, device=None): + """Replace fp16 linear with quantized linear""" + for layer in model.layers: + layer.self_attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.query_key_value.bias, + dtype=layer.self_attention.query_key_value.weight.dtype, + device=layer.self_attention.query_key_value.weight.device if device is None else device, + empty_init=empty_init + ) + layer.self_attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.dense.bias, + dtype=layer.self_attention.dense.weight.dtype, + device=layer.self_attention.dense.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_h_to_4h.bias, + dtype=layer.mlp.dense_h_to_4h.weight.dtype, + device=layer.mlp.dense_h_to_4h.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_4h_to_h.bias, + dtype=layer.mlp.dense_4h_to_h.weight.dtype, + device=layer.mlp.dense_4h_to_h.weight.device if device is None else device, + empty_init=empty_init + ) + + return model From 0d37788855a0078e53f6564172aea8ab7c5cfbd1 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 10 Jul 2023 18:55:33 +0800 Subject: [PATCH 06/15] [sharformer] add first version of policy of chatglm --- colossalai/shardformer/policies/chatglm.py | 47 ++++++++++++++++- .../test_model/test_shard_chatglm.py | 51 +++++++++++++++++++ .../test_model/test_shard_vit.py | 1 - 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 5a41b3ab361d..04dfb937f080 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -1,10 +1,55 @@ from typing import Dict, Union import torch.nn as nn +from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock -from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row +import colossalai.shardformer.layer as col_nn from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] +class ChatGLMModelPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + + policy[GLMBlock] = ModulePolicyDescription( + attribute_replacement = {}, + sub_module_replacement = [ + # SubModuleReplacementDescription( + # suffix = "self_attention.query_key_value", + # target_module = col_nn.Linear1D_Col, + # ), + # SubModuleReplacementDescription( + # suffix = "self_attention.dense", + # target_module = col_nn.Linear1D_Row, + # ) + # SubModuleReplacementDescription( + # suffix = "self_attention.core_attention.attention_dropout", + # target_module = col_nn.DropoutForParallelInput, + # ) + ],) + + + def postprocess(self): + return self.model diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index e69de29bb2d1..093b82de6b7b 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -0,0 +1,51 @@ +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, atol=1e-4, rtol=1e-4) + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model + if org_model.__class__.__name__ == 'ChatGLMModel': + vit_model = org_model + shard_vit_model = sharded_model + else: + chatglm_model = org_model.transformer + shard_chatglm_model = sharded_model.transformer + + # check attention grad + org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad + shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad + shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" \ No newline at end of file diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index a96fd02ae746..a8048a9bdd12 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -69,7 +69,6 @@ def check_vit(rank, world_size, port): @pytest.mark.dist -@pytest.mark.skip @rerun_if_address_is_in_use() @clear_cache_before_run() def test_vit(): From 8296b42127275b1a9eab15f662eb0d498104a4ca Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Wed, 12 Jul 2023 15:25:07 +0800 Subject: [PATCH 07/15] [shardformer] polish chatglm code --- colossalai/shardformer/policies/autopolicy.py | 3 + colossalai/shardformer/policies/chatglm.py | 69 +++++++++++++------ tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/chatglm.py | 16 +++-- .../test_model/test_shard_chatglm.py | 55 +++++++++++++-- 5 files changed, 110 insertions(+), 34 deletions(-) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 77583dd77cf0..52d9bd5ebeae 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -112,6 +112,9 @@ class PolicyLocation: # Sam "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), + # ChatGLM + "tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm.ChatGLMModel": + PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), } diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 04dfb937f080..df2b32724979 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -1,7 +1,6 @@ from typing import Dict, Union import torch.nn as nn -from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock import colossalai.shardformer.layer as col_nn @@ -9,11 +8,12 @@ __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] + class ChatGLMModelPolicy(Policy): def config_sanity_check(self): pass - + def preprocess(self): # Resize embedding vocab_size = self.model.config.vocab_size @@ -24,32 +24,59 @@ def preprocess(self): self.model.resize_token_embeddings(new_vocab_size) return self.model - + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock policy = {} if self.shard_config.enable_tensor_parallelism: - policy[GLMBlock] = ModulePolicyDescription( - attribute_replacement = {}, - sub_module_replacement = [ - # SubModuleReplacementDescription( - # suffix = "self_attention.query_key_value", - # target_module = col_nn.Linear1D_Col, - # ), - # SubModuleReplacementDescription( - # suffix = "self_attention.dense", - # target_module = col_nn.Linear1D_Row, - # ) - # SubModuleReplacementDescription( - # suffix = "self_attention.core_attention.attention_dropout", - # target_module = col_nn.DropoutForParallelInput, - # ) - ],) + policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) + policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ) + ]) + + return policy def postprocess(self): return self.model - diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 39e5ef411f32..08a118e5783d 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,6 +1,7 @@ from .albert import * from .bert import * from .bloom import * +from .chatglm import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index bed5e4bfb429..8de6b74e4594 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -1,17 +1,19 @@ import torch import transformers -from chatglm2_6b.configuration_chatglm import ChatGLMConfig from ..registry import ModelAttribute, model_zoo +from .chatglm2_6b.configuration_chatglm import ChatGLMConfig +from .chatglm2_6b.modeling_chatglm import ChatGLMModel # ================================ # Register single-sentence ChatGLM # ================================ -config = ChatGLMConfig(num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, - num_attention_heads=4) +config = ChatGLMConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) + +config.__setattr__('original_rope', True) +config.__setattr__('use_cache', True) + def data_gen(): @@ -19,6 +21,7 @@ def data_gen(): attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) + # define output transform function output_transform_fn = lambda x: x @@ -27,9 +30,8 @@ def data_gen(): loss_fn = lambda x: x.loss model_zoo.register(name='transformers_chatglm', - model_fn=lambda: transformers.ChatGLMModel(config), + model_fn=lambda: ChatGLMModel(config), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_chatglm_model, model_attribute=ModelAttribute(has_control_flow=True)) - diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 093b82de6b7b..471282878995 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -16,6 +16,7 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward + def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, @@ -30,16 +31,16 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # unwrap model if org_model.__class__.__name__ == 'ChatGLMModel': - vit_model = org_model - shard_vit_model = sharded_model + chatglm_model = org_model + shard_chatglm_model = sharded_model else: chatglm_model = org_model.transformer shard_chatglm_model = sharded_model.transformer # check attention grad - org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad - shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad - shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight + org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad + shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad + shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] @@ -48,4 +49,46 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo else: all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" \ No newline at end of file + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + # check embedding weights + org_grad = chatglm_model.embedding.word_embeddings.weight.grad + shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad + shard_weight = shard_chatglm_model.embedding.word_embeddings.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_chatglm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm(): + spawn(check_chatglm, 2) + + +if __name__ == "__main__": + test_chatglm() From 425f17475de882f6a527818a992944cc6da03ddd Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Thu, 13 Jul 2023 19:51:25 +0800 Subject: [PATCH 08/15] [shardformer] polish code --- colossalai/shardformer/policies/chatglm.py | 2 +- tests/kit/model_zoo/transformers/chatglm.py | 8 +- .../chatglm2_6b/modeling_chatglm.py | 481 +++++++++--------- 3 files changed, 255 insertions(+), 236 deletions(-) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index df2b32724979..74d14183f8ea 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -16,7 +16,7 @@ def config_sanity_check(self): def preprocess(self): # Resize embedding - vocab_size = self.model.config.vocab_size + vocab_size = self.model.config.padded_vocab_size world_size = self.shard_config.tensor_parallel_size if vocab_size % world_size != 0: diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 8de6b74e4594..251110958278 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -9,7 +9,7 @@ # Register single-sentence ChatGLM # ================================ -config = ChatGLMConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) +config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8) config.__setattr__('original_rope', True) config.__setattr__('use_cache', True) @@ -17,8 +17,8 @@ def data_gen(): - input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -30,7 +30,7 @@ def data_gen(): loss_fn = lambda x: x.loss model_zoo.register(name='transformers_chatglm', - model_fn=lambda: ChatGLMModel(config), + model_fn=lambda: ChatGLMModel(config, empty_init=False), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_chatglm_model, diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 82163c46190f..3c8e3251399c 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -1,27 +1,23 @@ """ PyTorch ChatGLM model. """ -import math import copy -import warnings +import math import re import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint import torch.nn.functional as F +import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput from .configuration_chatglm import ChatGLMConfig @@ -49,6 +45,7 @@ def default_init(cls, *args, **kwargs): class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() @@ -70,14 +67,11 @@ def __init__(self, config: ChatGLMConfig): # Use a two-layer MLP to encode the prefix kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size) - ) + self.trans = torch.nn.Sequential(torch.nn.Linear(kv_size, config.hidden_size), torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size)) else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + self.embedding = torch.nn.Embedding( + config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2) def forward(self, prefix: torch.Tensor): if self.prefix_projection: @@ -89,9 +83,9 @@ def forward(self, prefix: torch.Tensor): def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, ) -> List[torch.Tensor]: """Split a tensor along its last dimension. @@ -117,16 +111,15 @@ def split_tensor_along_last_dim( class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) self.register_buffer("inv_freq", inv_freq) self.dim = dim self.original_impl = original_impl - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): + def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000): """Enhanced Transformer with Rotary Position Embedding. Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ @@ -134,7 +127,7 @@ def forward_impl( https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. """ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, dtype=dtype, device=device) @@ -150,9 +143,7 @@ def forward_impl( return cache def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) + return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device) @torch.jit.script @@ -177,6 +168,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) @@ -191,6 +183,7 @@ def forward(self, hidden_states: torch.Tensor): class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): super(CoreAttention, self).__init__() @@ -221,7 +214,9 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): if pytorch_major_version >= 2: query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, is_causal=True) else: if attention_mask is not None: @@ -243,16 +238,17 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) + matmul_input_buffer = torch.empty(output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device) # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0 / self.norm_factor), ) @@ -270,8 +266,12 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): if self.coeff is not None: attention_scores = attention_scores * self.coeff if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) + attention_mask = torch.ones(output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool) attention_mask.tril_() attention_mask = ~attention_mask if attention_mask is not None: @@ -329,20 +329,22 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.qkv_hidden_size = 3 * self.projection_size if self.multi_query_attention: self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + self.qkv_hidden_size = (self.projection_size + + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) + self.query_key_value = nn.Linear(config.hidden_size, + self.qkv_hidden_size, bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) + device=device, + **_config_to_kwargs(config)) self.core_attention = CoreAttention(config, self.layer_number) # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) + self.dense = nn.Linear(self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config)) def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): if self.multi_query_attention: @@ -358,9 +360,7 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, device=device, ) - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): + def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): # hidden_states: [sq, b, h] # ================================================= @@ -382,22 +382,17 @@ def forward( ], dim=-1, ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) + query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + key_layer = key_layer.view(key_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)) + value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) else: new_tensor_shape = mixed_x_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) @@ -419,18 +414,15 @@ def forward( if self.multi_query_attention: key_layer = key_layer.unsqueeze(-2) key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) value_layer = value_layer.unsqueeze(-2) value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) # ================================== # core attention computation @@ -468,13 +460,11 @@ def __init__(self, config: ChatGLMConfig, device=None): self.add_bias = config.add_bias_linear # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) + self.dense_h_to_4h = nn.Linear(config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config)) def swiglu(x): x = torch.chunk(x, 2, dim=-1) @@ -483,13 +473,11 @@ def swiglu(x): self.activation_func = swiglu # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) + self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config)) def forward(self, hidden_states): # [s, b, 4hp] @@ -517,7 +505,9 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + self.input_layernorm = LayerNormFunc(config.hidden_size, + eps=config.layernorm_epsilon, + device=device, dtype=config.torch_dtype) # Self attention. @@ -525,27 +515,32 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, + eps=config.layernorm_epsilon, + device=device, dtype=config.torch_dtype) # MLP self.mlp = MLP(config, device=device) def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, ): # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) + attention_output, kv_cache = self.self_attention(layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache) # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -595,7 +590,9 @@ def build_layer(layer_number): if self.post_layer_norm: LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + self.final_layernorm = LayerNormFunc(config.hidden_size, + eps=config.layernorm_epsilon, + device=device, dtype=config.torch_dtype) self.gradient_checkpointing = False @@ -604,9 +601,13 @@ def _get_layer(self, layer_number): return self.layers[layer_number] def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, ): if not kv_caches: kv_caches = [None for _ in range(self.num_layers)] @@ -614,8 +615,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False all_self_attentions = None @@ -626,22 +626,14 @@ def forward( layer = self._get_layer(index) if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache - ) + layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb, + kv_caches[index], use_cache) else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) + layer_ret = layer(hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache) hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) @@ -680,8 +672,8 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): if past_key_values: past_length = past_key_values[0][0].shape[0] if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) + full_attention_mask = torch.cat( + (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1) if padding_mask is not None: full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) if not past_length and padding_mask is not None: @@ -708,12 +700,10 @@ def __init__(self, config: ChatGLMConfig, device=None): self.hidden_size = config.hidden_size # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) + self.word_embeddings = nn.Embedding(config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device) self.fp32_residual_connection = config.fp32_residual_connection def forward(self, input_ids): @@ -729,6 +719,7 @@ def forward(self, input_ids): class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): super().__init__(config) if empty_init: @@ -745,15 +736,20 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): # Rotary positional embeddings self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) + rotary_dim = (config.hidden_size // + config.num_attention_heads if config.kv_channels is None else config.kv_channels) - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, + original_impl=config.original_rope, + device=device, dtype=config.torch_dtype) self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) + self.output_layer = init_method(nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs) self.pre_seq_len = config.pre_seq_len self.prefix_projection = config.prefix_projection if self.pre_seq_len is not None: @@ -769,33 +765,27 @@ def get_input_embeddings(self): def get_prompt(self, batch_size, device, dtype=torch.half): prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels - ) + past_key_values = past_key_values.view(batch_size, self.pre_seq_len, self.num_layers * 2, + self.multi_query_group_num, self.kv_channels) # seq_len, b, nh, hidden_size past_key_values = self.dropout(past_key_values) past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) return past_key_values def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -803,14 +793,16 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) + print(inputs_embeds) if self.pre_seq_len is not None: if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + past_key_values = self.get_prompt(batch_size=batch_size, + device=input_ids.device, dtype=inputs_embeds.dtype) if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask], dim=-1) + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], + dim=-1) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): @@ -826,9 +818,12 @@ def forward( # Run encoder. hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states) if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) @@ -847,6 +842,7 @@ def quantize(self, weight_bit_width: int): class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): super().__init__(config) @@ -859,45 +855,39 @@ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): self.quantize(self.config.quantization_bit, empty_init=True) def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) + outputs, standardize_cache_format=standardize_cache_format) # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) # update position ids if "position_ids" in model_kwargs: position_ids = model_kwargs["position_ids"] new_position_id = position_ids[..., -1:].clone() new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) model_kwargs["is_first_forward"] = False return model_kwargs - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: + def prepare_inputs_for_generation(self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs) -> dict: # only last token for input_ids if past is not None if position_ids is None: position_ids = self.get_position_ids(input_ids, device=input_ids.device) @@ -913,18 +903,18 @@ def prepare_inputs_for_generation( } def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, ): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -973,9 +963,8 @@ def forward( ) @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], + beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct @@ -983,13 +972,10 @@ def _reorder_cache( Output shares the same memory storage as `past`. """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) def process_response(self, response): response = response.strip() @@ -1015,15 +1001,31 @@ def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, st return inputs @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1, - do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): + def chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 8192, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs): if history is None: history = [] if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs + } inputs = self.build_inputs(tokenizer, query, history=history) outputs = self.generate(**inputs, **gen_kwargs) outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] @@ -1033,16 +1035,31 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max return response, history @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None, - max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - return_past_key_values=False, **kwargs): + def stream_chat(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 8192, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs): if history is None: history = [] if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs + } if past_key_values is None and not return_past_key_values: inputs = self.build_inputs(tokenizer, query, history=history) else: @@ -1055,8 +1072,10 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No attention_mask = inputs.attention_mask attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) inputs['attention_mask'] = attention_mask - for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, - return_past_key_values=return_past_key_values, **gen_kwargs): + for outputs in self.stream_generate(**inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs): if return_past_key_values: outputs, past_key_values = outputs outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] @@ -1071,14 +1090,14 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No @torch.no_grad() def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, ): batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] @@ -1112,11 +1131,9 @@ def stream_generate( if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) + logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`.") # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -1130,9 +1147,8 @@ def stream_generate( logits_processor=logits_processor, ) - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) + stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, + stopping_criteria=stopping_criteria) logits_warper = self._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) @@ -1162,9 +1178,9 @@ def stream_generate( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) + model_kwargs = self._update_model_kwargs_for_generation(outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) if return_past_key_values: yield input_ids, outputs.past_key_values @@ -1188,6 +1204,9 @@ def quantize(self, bits: int, empty_init=False, device=None, **kwargs): self.config.quantization_bit = bits - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, + self.transformer.encoder = quantize(self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, **kwargs) - return self \ No newline at end of file + return self From ee66714d57f9e03716d5871a53fc922f26085f56 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Fri, 14 Jul 2023 18:10:52 +0800 Subject: [PATCH 09/15] [shardformer] support chatglm without layernorm --- colossalai/shardformer/policies/chatglm.py | 7 ++++++- tests/kit/model_zoo/transformers/chatglm.py | 2 +- .../model_zoo/transformers/chatglm2_6b/modeling_chatglm.py | 7 +++---- tests/test_shardformer/test_model/test_shard_chatglm.py | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 74d14183f8ea..b20967ec34d2 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -44,7 +44,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, "self_attention.projection_size": - self.model.config.kv_channels * self.model.config.num_attention_heads // + (self.model.config.kv_channels * self.model.config.num_attention_heads) // + self.shard_config.tensor_parallel_size, + "self_attention.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // self.shard_config.tensor_parallel_size, "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 251110958278..4ac89db53637 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -9,7 +9,7 @@ # Register single-sentence ChatGLM # ================================ -config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8) +config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, rmsnorm=False) config.__setattr__('original_rope', True) config.__setattr__('use_cache', True) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 3c8e3251399c..9dd1c660e4fd 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -320,7 +320,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.layer_number = max(1, layer_number) self.projection_size = config.kv_channels * config.num_attention_heads - + self.hidden_size = config.hidden_size # Per attention head and per partition values. self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads @@ -331,7 +331,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.num_multi_query_groups_per_partition = config.multi_query_group_num self.qkv_hidden_size = (self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) - self.query_key_value = nn.Linear(config.hidden_size, + self.query_key_value = nn.Linear(self.hidden_size, self.qkv_hidden_size, bias=config.add_bias_linear or config.add_qkv_bias, device=device, @@ -341,7 +341,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): # Output. self.dense = nn.Linear(self.projection_size, - config.hidden_size, + self.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config)) @@ -793,7 +793,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) - print(inputs_embeds) if self.pre_seq_len is not None: if past_key_values is None: diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 471282878995..bc4ec94b8661 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -21,7 +21,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, atol=1e-4, rtol=1e-4) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) # do backward org_loss.backward() shard_loss.backward() From 5c82e62808da0d1e0611bb60a32184abb7946da5 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Fri, 14 Jul 2023 18:53:47 +0800 Subject: [PATCH 10/15] [shardformer] chatglm shard without mlp sharding --- colossalai/shardformer/policies/chatglm.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index b20967ec34d2..760656e83f73 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -71,14 +71,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: suffix="self_attention.core_attention.attention_dropout", target_module=col_nn.DropoutForParallelInput, ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=col_nn.Linear1D_Row, - ) ]) return policy From 3ed27a209fdd4a1d6d83f71c1fd23c8898e76a50 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 17 Jul 2023 15:10:15 +0800 Subject: [PATCH 11/15] [shardformer] delete some file --- =2.0 | 134 ------------- tests/kit/model_zoo/transformers/chatglm.py | 13 +- .../transformers/chatglm2_6b/MODEL_LICENSE | 33 --- .../chatglm2_6b/modeling_chatglm.py | 35 ++++ .../transformers/chatglm2_6b/quantization.py | 188 ------------------ 5 files changed, 42 insertions(+), 361 deletions(-) delete mode 100644 =2.0 delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py diff --git a/=2.0 b/=2.0 deleted file mode 100644 index af47ce17aa8e..000000000000 --- a/=2.0 +++ /dev/null @@ -1,134 +0,0 @@ -Defaulting to user installation because normal site-packages is not writeable -Collecting protobuf - Using cached protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl (304 kB) -Requirement already satisfied: transformers==4.30.2 in /home/lclk/.local/lib/python3.9/site-packages (4.30.2) -Collecting cpm_kernels - Using cached cpm_kernels-1.0.11-py3-none-any.whl (416 kB) -Requirement already satisfied: torch in /home/lclk/.local/lib/python3.9/site-packages (2.0.0+cu118) -Collecting gradio - Using cached gradio-3.36.0-py3-none-any.whl (19.8 MB) -Collecting mdtex2html - Using cached mdtex2html-1.2.0-py3-none-any.whl (13 kB) -Collecting sentencepiece - Using cached sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB) -Collecting accelerate - Using cached accelerate-0.20.3-py3-none-any.whl (227 kB) -Requirement already satisfied: pyyaml>=5.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (6.0) -Requirement already satisfied: regex!=2019.12.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (2023.6.3) -Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.15.1) -Requirement already satisfied: packaging>=20.0 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (23.1) -Requirement already satisfied: requests in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from transformers==4.30.2) (2.25.1) -Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.13.3) -Requirement already satisfied: safetensors>=0.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.3.1) -Requirement already satisfied: filelock in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (3.12.0) -Requirement already satisfied: numpy>=1.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (1.24.3) -Requirement already satisfied: tqdm>=4.27 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (4.65.0) -Requirement already satisfied: fsspec in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (2023.6.0) -Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (4.6.3) -Requirement already satisfied: networkx in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1) -Requirement already satisfied: sympy in /home/lclk/.local/lib/python3.9/site-packages (from torch) (1.12) -Requirement already satisfied: triton==2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (2.0.0) -Requirement already satisfied: jinja2 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1.2) -Requirement already satisfied: lit in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (16.0.5.post0) -Requirement already satisfied: cmake in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (3.26.3) -Collecting aiofiles - Using cached aiofiles-23.1.0-py3-none-any.whl (14 kB) -Collecting ffmpy - Using cached ffmpy-0.3.0.tar.gz (4.8 kB) -Requirement already satisfied: pillow in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (9.5.0) -Collecting pydub - Using cached pydub-0.25.1-py2.py3-none-any.whl (32 kB) -Requirement already satisfied: pandas in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.0.2) -Collecting python-multipart - Using cached python_multipart-0.0.6-py3-none-any.whl (45 kB) -Collecting semantic-version - Using cached semantic_version-2.10.0-py2.py3-none-any.whl (15 kB) -Collecting pydantic - Using cached pydantic-2.0.2-py3-none-any.whl (359 kB) -Collecting uvicorn>=0.14.0 - Using cached uvicorn-0.22.0-py3-none-any.whl (58 kB) -Collecting mdit-py-plugins<=0.3.3 - Using cached mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB) -Requirement already satisfied: pygments>=2.12.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.15.1) -Collecting httpx - Using cached httpx-0.24.1-py3-none-any.whl (75 kB) -Collecting orjson - Using cached orjson-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (136 kB) -Collecting fastapi - Using cached fastapi-0.99.1-py3-none-any.whl (58 kB) -Collecting altair>=4.2.0 - Using cached altair-5.0.1-py3-none-any.whl (471 kB) -Collecting gradio-client>=0.2.7 - Using cached gradio_client-0.2.7-py3-none-any.whl (288 kB) -Requirement already satisfied: aiohttp in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.8.4) -Requirement already satisfied: matplotlib in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.7.1) -Collecting websockets>=10.0 - Using cached websockets-11.0.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB) -Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.2.0) -Requirement already satisfied: markupsafe in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.1.3) -Collecting toolz - Using cached toolz-0.12.0-py3-none-any.whl (55 kB) -Collecting jsonschema>=3.0 - Using cached jsonschema-4.18.0-py3-none-any.whl (81 kB) -Collecting rpds-py>=0.7.1 - Downloading rpds_py-0.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB) -Collecting referencing>=0.28.4 - Using cached referencing-0.29.1-py3-none-any.whl (25 kB) -Collecting jsonschema-specifications>=2023.03.6 - Using cached jsonschema_specifications-2023.6.1-py3-none-any.whl (17 kB) -Requirement already satisfied: attrs>=22.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from jsonschema>=3.0->altair>=4.2.0->gradio) (23.1.0) -Requirement already satisfied: mdurl~=0.1 in /home/lclk/.local/lib/python3.9/site-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2) -Collecting linkify-it-py<3,>=1 - Downloading linkify_it_py-2.0.2-py3-none-any.whl (19 kB) -Collecting uc-micro-py - Downloading uc_micro_py-1.0.2-py3-none-any.whl (6.2 kB) -Requirement already satisfied: pytz>=2020.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) -Requirement already satisfied: tzdata>=2022.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) -Requirement already satisfied: python-dateutil>=2.8.2 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2.8.2) -Requirement already satisfied: six>=1.5 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas->gradio) (1.16.0) -Requirement already satisfied: click>=7.0 in /home/lclk/.local/lib/python3.9/site-packages (from uvicorn>=0.14.0->gradio) (8.1.3) -Collecting h11>=0.8 - Downloading h11-0.14.0-py3-none-any.whl (58 kB) -Collecting latex2mathml - Downloading latex2mathml-3.76.0-py3-none-any.whl (73 kB) -Collecting markdown - Downloading Markdown-3.4.3-py3-none-any.whl (93 kB) -Requirement already satisfied: psutil in /home/lclk/.local/lib/python3.9/site-packages (from accelerate) (5.9.5) -Requirement already satisfied: multidict<7.0,>=4.5 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (6.0.4) -Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (4.0.2) -Requirement already satisfied: aiosignal>=1.1.2 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.1) -Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (3.1.0) -Requirement already satisfied: frozenlist>=1.1.1 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.3) -Requirement already satisfied: yarl<2.0,>=1.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.9.2) -Requirement already satisfied: idna>=2.0 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from yarl<2.0,>=1.0->aiohttp->gradio) (2.10) -Collecting pydantic - Downloading pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB) -Collecting starlette<0.28.0,>=0.27.0 - Downloading starlette-0.27.0-py3-none-any.whl (66 kB) -Collecting anyio<5,>=3.4.0 - Downloading anyio-3.7.1-py3-none-any.whl (80 kB) -Collecting sniffio>=1.1 - Downloading sniffio-1.3.0-py3-none-any.whl (10 kB) -Requirement already satisfied: exceptiongroup in /home/lclk/.local/lib/python3.9/site-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->gradio) (1.1.1) -Collecting httpcore<0.18.0,>=0.15.0 - Downloading httpcore-0.17.3-py3-none-any.whl (74 kB) -Requirement already satisfied: certifi in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from httpx->gradio) (2021.5.30) -Requirement already satisfied: importlib-metadata>=4.4 in /home/lclk/.local/lib/python3.9/site-packages (from markdown->mdtex2html) (6.7.0) -Requirement already satisfied: zipp>=0.5 in /home/lclk/.local/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown->mdtex2html) (3.15.0) -Requirement already satisfied: contourpy>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.1.0) -Requirement already satisfied: fonttools>=4.22.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (4.40.0) -Requirement already satisfied: pyparsing>=2.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (3.1.0) -Requirement already satisfied: kiwisolver>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.4.4) -Requirement already satisfied: importlib-resources>=3.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (5.12.0) -Requirement already satisfied: cycler>=0.10 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (0.11.0) -Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (1.26.6) -Requirement already satisfied: chardet<5,>=3.0.2 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (4.0.0) -Requirement already satisfied: mpmath>=0.19 in /home/lclk/.local/lib/python3.9/site-packages (from sympy->torch) (1.3.0) -Building wheels for collected packages: ffmpy - Building wheel for ffmpy (setup.py): started - Building wheel for ffmpy (setup.py): finished with status 'done' - Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4709 sha256=071cebb58ca6c6947fbc669e1d94509d6f53d1ed45d9d7fb9f060d1a342cfc18 - Stored in directory: /home/lclk/.cache/pip/wheels/91/e2/96/f676aa08bfd789328c6576cd0f1fde4a3d686703bb0c247697 -Successfully built ffmpy -Installing collected packages: sniffio, rpds-py, referencing, h11, anyio, uc-micro-py, jsonschema-specifications, httpcore, websockets, toolz, starlette, pydantic, linkify-it-py, jsonschema, httpx, uvicorn, semantic-version, python-multipart, pydub, orjson, mdit-py-plugins, markdown, latex2mathml, gradio-client, ffmpy, fastapi, altair, aiofiles, sentencepiece, protobuf, mdtex2html, gradio, cpm-kernels, accelerate -Successfully installed accelerate-0.20.3 aiofiles-23.1.0 altair-5.0.1 anyio-3.7.1 cpm-kernels-1.0.11 fastapi-0.99.1 ffmpy-0.3.0 gradio-3.36.0 gradio-client-0.2.7 h11-0.14.0 httpcore-0.17.3 httpx-0.24.1 jsonschema-4.18.0 jsonschema-specifications-2023.6.1 latex2mathml-3.76.0 linkify-it-py-2.0.2 markdown-3.4.3 mdit-py-plugins-0.3.3 mdtex2html-1.2.0 orjson-3.9.1 protobuf-4.23.4 pydantic-1.10.11 pydub-0.25.1 python-multipart-0.0.6 referencing-0.29.1 rpds-py-0.8.8 semantic-version-2.10.0 sentencepiece-0.1.99 sniffio-1.3.0 starlette-0.27.0 toolz-0.12.0 uc-micro-py-1.0.2 uvicorn-0.22.0 websockets-11.0.3 diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 4ac89db53637..1408babede64 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -9,14 +9,8 @@ # Register single-sentence ChatGLM # ================================ -config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, rmsnorm=False) - -config.__setattr__('original_rope', True) -config.__setattr__('use_cache', True) - def data_gen(): - input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -28,6 +22,13 @@ def data_gen(): # define loss function loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean() loss_fn = lambda x: x.loss +config = ChatGLMConfig(num_layers=1, + padded_vocab_size=65024, + hidden_size=64, + num_attention_heads=8, + rmsnorm=False, + original_rope=True, + use_cache=True) model_zoo.register(name='transformers_chatglm', model_fn=lambda: ChatGLMModel(config, empty_init=False), diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE b/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE deleted file mode 100644 index 26198b21b6b2..000000000000 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE +++ /dev/null @@ -1,33 +0,0 @@ -The ChatGLM2-6B License - -1. Definitions - -“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. - -“Software” means the ChatGLM2-6B model parameters made available under this license. - -2. License Grant - -Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -3. Restriction - -You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. - -You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. - -4. Disclaimer - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -5. Limitation of Liability - -EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - -6. Dispute Resolution - -This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. - -Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 9dd1c660e4fd..4bacdaa542c0 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -1,3 +1,38 @@ +""" +The ChatGLM2-6B License + +1. Definitions + +“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. + +“Software” means the ChatGLM2-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. +""" """ PyTorch ChatGLM model. """ import copy diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py b/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py deleted file mode 100644 index cb95bfe82b20..000000000000 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py +++ /dev/null @@ -1,188 +0,0 @@ -from torch.nn import Linear -from torch.nn.parameter import Parameter - -import bz2 -import torch -import base64 -import ctypes -from transformers.utils import logging - -from typing import List -from functools import partial - -logger = logging.get_logger(__name__) - -try: - from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up - - class Kernel: - def __init__(self, code: bytes, function_names: List[str]): - self.code = code - self._function_names = function_names - self._cmodule = LazyKernelCModule(self.code) - - for name in self._function_names: - setattr(self, name, KernelFunction(self._cmodule, name)) - - quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" - - kernels = Kernel( - bz2.decompress(base64.b64decode(quantization_code)), - [ - "int4WeightCompression", - "int4WeightExtractionFloat", - "int4WeightExtractionHalf", - "int8WeightExtractionFloat", - "int8WeightExtractionHalf", - ], - ) -except Exception as exception: - kernels = None - logger.warning("Failed to load cpm_kernels:" + str(exception)) - - -class W8A16Linear(torch.autograd.Function): - @staticmethod - def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): - ctx.inp_shape = inp.size() - ctx.weight_bit_width = weight_bit_width - out_features = quant_w.size(0) - inp = inp.contiguous().view(-1, inp.size(-1)) - weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) - ctx.weight_shape = weight.size() - output = inp.mm(weight.t()) - ctx.save_for_backward(inp, quant_w, scale_w) - return output.view(*(ctx.inp_shape[:-1] + (out_features,))) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - inp, quant_w, scale_w = ctx.saved_tensors - weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) - grad_output = grad_output.contiguous().view(-1, weight.size(0)) - grad_input = grad_output.mm(weight) - grad_weight = grad_output.t().mm(inp) - return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None - - -def compress_int4_weight(weight: torch.Tensor): # (n, m) - with torch.cuda.device(weight.device): - n, m = weight.size(0), weight.size(1) - assert m % 2 == 0 - m = m // 2 - out = torch.empty(n, m, dtype=torch.int8, device="cuda") - stream = torch.cuda.current_stream() - - gridDim = (n, 1, 1) - blockDim = (min(round_up(m, 32), 1024), 1, 1) - - kernels.int4WeightCompression( - gridDim, - blockDim, - 0, - stream, - [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], - ) - return out - - -def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): - assert scale_list.dtype in [torch.half, torch.bfloat16] - assert weight.dtype in [torch.int8] - if source_bit_width == 8: - return weight.to(scale_list.dtype) * scale_list[:, None] - elif source_bit_width == 4: - func = ( - kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16 - ) - else: - assert False, "Unsupported bit-width" - - with torch.cuda.device(weight.device): - n, m = weight.size(0), weight.size(1) - out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda") - stream = torch.cuda.current_stream() - - gridDim = (n, 1, 1) - blockDim = (min(round_up(m, 32), 1024), 1, 1) - - func( - gridDim, - blockDim, - 0, - stream, - [ - ctypes.c_void_p(weight.data_ptr()), - ctypes.c_void_p(scale_list.data_ptr()), - ctypes.c_void_p(out.data_ptr()), - ctypes.c_int32(n), - ctypes.c_int32(m), - ], - ) - return out - - -class QuantizedLinear(torch.nn.Module): - def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args, - **kwargs): - super().__init__() - self.weight_bit_width = weight_bit_width - - shape = weight.shape - - if weight is None or empty_init: - self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device) - self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device) - else: - self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) - self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) - if weight_bit_width == 4: - self.weight = compress_int4_weight(self.weight) - - self.weight = Parameter(self.weight.to(device), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False) - self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None - - def forward(self, input): - output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) - if self.bias is not None: - output = output + self.bias - return output - - -def quantize(model, weight_bit_width, empty_init=False, device=None): - """Replace fp16 linear with quantized linear""" - for layer in model.layers: - layer.self_attention.query_key_value = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()), - bias=layer.self_attention.query_key_value.bias, - dtype=layer.self_attention.query_key_value.weight.dtype, - device=layer.self_attention.query_key_value.weight.device if device is None else device, - empty_init=empty_init - ) - layer.self_attention.dense = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()), - bias=layer.self_attention.dense.bias, - dtype=layer.self_attention.dense.weight.dtype, - device=layer.self_attention.dense.weight.device if device is None else device, - empty_init=empty_init - ) - layer.mlp.dense_h_to_4h = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), - bias=layer.mlp.dense_h_to_4h.bias, - dtype=layer.mlp.dense_h_to_4h.weight.dtype, - device=layer.mlp.dense_h_to_4h.weight.device if device is None else device, - empty_init=empty_init - ) - layer.mlp.dense_4h_to_h = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), - bias=layer.mlp.dense_4h_to_h.bias, - dtype=layer.mlp.dense_4h_to_h.weight.dtype, - device=layer.mlp.dense_4h_to_h.weight.device if device is None else device, - empty_init=empty_init - ) - - return model From 92fff75d8058c3ba9087ae9d7b5e7cf3987ecf57 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 17 Jul 2023 19:47:57 +0800 Subject: [PATCH 12/15] [shardformer] ChatGLM support layernorm sharding --- colossalai/shardformer/policies/chatglm.py | 21 +++++++++++++++++-- .../chatglm2_6b/modeling_chatglm.py | 5 ++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 760656e83f73..934b99b83ea1 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -46,8 +46,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "self_attention.projection_size": (self.model.config.kv_channels * self.model.config.num_attention_heads) // self.shard_config.tensor_parallel_size, - "self_attention.hidden_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attention.qkv_hidden_size": (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // self.shard_config.tensor_parallel_size, @@ -72,6 +70,25 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.DropoutForParallelInput, ), ]) + # optimization configuration + if self.shard_config.enable_fused_normalization: + if not self.model.config.rmsnorm: + + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), + SubModuleReplacementDescription(suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm) + ], + policy=policy, + target_key=GLMBlock) + + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="encoder.final_layernorm", + target_module=col_nn.FusedLayerNorm) + ], + policy=policy, + target_key=ChatGLMModel) return policy diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 4bacdaa542c0..aed8de834b19 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -355,7 +355,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.layer_number = max(1, layer_number) self.projection_size = config.kv_channels * config.num_attention_heads - self.hidden_size = config.hidden_size # Per attention head and per partition values. self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads @@ -366,7 +365,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.num_multi_query_groups_per_partition = config.multi_query_group_num self.qkv_hidden_size = (self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) - self.query_key_value = nn.Linear(self.hidden_size, + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, bias=config.add_bias_linear or config.add_qkv_bias, device=device, @@ -376,7 +375,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): # Output. self.dense = nn.Linear(self.projection_size, - self.hidden_size, + config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config)) From 1d52b2b8d9371bbb8e3fd6e9dfbe8205315eba43 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 18 Jul 2023 12:33:12 +0800 Subject: [PATCH 13/15] [shardformer] register without auto policy --- colossalai/shardformer/policies/autopolicy.py | 3 --- .../test_model/test_shard_chatglm.py | 15 ++++++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 52d9bd5ebeae..77583dd77cf0 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -112,9 +112,6 @@ class PolicyLocation: # Sam "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), - # ChatGLM - "tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), } diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index bc4ec94b8661..2cdf5da2e6da 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -1,3 +1,4 @@ +import copy import os import pytest @@ -5,6 +6,8 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -72,7 +75,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + # create new model + org_model = model_fn().cuda() + + # shard model + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + if name == "transformers_chatglm": + sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda() + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From ae17a32a3ac67436d773a0a8ac91baa9778623b2 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Wed, 19 Jul 2023 11:39:59 +0800 Subject: [PATCH 14/15] [shardformer] pre-commit check files --- .../chatglm2_6b/modeling_chatglm.py | 455 +++++++++++------- 1 file changed, 291 insertions(+), 164 deletions(-) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index aed8de834b19..bae6d425878d 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -58,7 +58,7 @@ # flags required to enable jit fusion kernels -if sys.platform != 'darwin': +if sys.platform != "darwin": torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) @@ -100,13 +100,18 @@ def __init__(self, config: ChatGLMConfig): self.prefix_projection = config.prefix_projection if self.prefix_projection: # Use a two-layer MLP to encode the prefix - kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2) self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential(torch.nn.Linear(kv_size, config.hidden_size), torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size)) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size), + ) else: self.embedding = torch.nn.Embedding( - config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2, + ) def forward(self, prefix: torch.Tensor): if self.prefix_projection: @@ -154,7 +159,14 @@ def __init__(self, dim, original_impl=False, device=None, dtype=None): self.dim = dim self.original_impl = original_impl - def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000): + def forward_impl( + self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000, + ): """Enhanced Transformer with Rotary Position Embedding. Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ @@ -178,7 +190,12 @@ def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: to return cache def forward(self, max_seq_len, offset=0): - return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device) + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device, + ) @torch.jit.script @@ -232,7 +249,7 @@ def __init__(self, config: ChatGLMConfig, layer_number): # Per attention head and per partition values. self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads) self.num_attention_heads_per_partition = config.num_attention_heads coeff = None @@ -245,7 +262,7 @@ def __init__(self, config: ChatGLMConfig, layer_number): self.attention_dropout = torch.nn.Dropout(config.attention_dropout) def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) + pytorch_major_version = int(torch.__version__.split(".")[0]) if pytorch_major_version >= 2: query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: @@ -265,7 +282,12 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # Raw attention scores # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) @@ -273,11 +295,13 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty(output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=query_layer.device) + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( @@ -300,13 +324,15 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): attention_scores = attention_scores.float() if self.coeff is not None: attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], - 1, - output_size[2], - output_size[3], - device=attention_scores.device, - dtype=torch.bool) + if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]): + attention_mask = torch.ones( + output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool, + ) attention_mask.tril_() attention_mask = ~attention_mask if attention_mask is not None: @@ -325,7 +351,12 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) # change view [sk, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] @@ -356,7 +387,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. - self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) self.num_attention_heads_per_partition = config.num_attention_heads self.multi_query_attention = config.multi_query_attention @@ -365,20 +396,24 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.num_multi_query_groups_per_partition = config.multi_query_group_num self.qkv_hidden_size = (self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) - self.query_key_value = nn.Linear(config.hidden_size, - self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, - **_config_to_kwargs(config)) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config), + ) self.core_attention = CoreAttention(config, self.layer_number) # Output. - self.dense = nn.Linear(self.projection_size, - config.hidden_size, - bias=config.add_bias_linear, - device=device, - **_config_to_kwargs(config)) + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config), + ) def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): if self.multi_query_attention: @@ -394,7 +429,14 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, device=device, ) - def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): # hidden_states: [sq, b, h] # ================================================= @@ -416,16 +458,23 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, ], dim=-1, ) - query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head)) - key_layer = key_layer.view(key_layer.size()[:-1] + - (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)) - value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head)) + query_layer = query_layer.view(query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + key_layer = key_layer.view(key_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.view(value_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) @@ -448,15 +497,28 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, if self.multi_query_attention: key_layer = key_layer.unsqueeze(-2) key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) - key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head)) + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) value_layer = value_layer.unsqueeze(-2) value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) - value_layer = value_layer.contiguous().view(value_layer.size()[:2] + - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head)) + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) # ================================== # core attention computation @@ -494,11 +556,13 @@ def __init__(self, config: ChatGLMConfig, device=None): self.add_bias = config.add_bias_linear # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear(config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config)) + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) def swiglu(x): x = torch.chunk(x, 2, dim=-1) @@ -507,11 +571,13 @@ def swiglu(x): self.activation_func = swiglu # Project back to h. - self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config)) + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) def forward(self, hidden_states): # [s, b, 4hp] @@ -533,26 +599,30 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): super(GLMBlock, self).__init__() self.layer_number = layer_number - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm) self.fp32_residual_connection = config.fp32_residual_connection LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, - eps=config.layernorm_epsilon, - device=device, - dtype=config.torch_dtype) + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) # Self attention. self.self_attention = SelfAttention(config, layer_number, device=device) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, - eps=config.layernorm_epsilon, - device=device, - dtype=config.torch_dtype) + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) # MLP self.mlp = MLP(config, device=device) @@ -570,11 +640,13 @@ def forward( # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output, kv_cache = self.self_attention(layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache) + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) # Residual connection. if self.apply_residual_connection_post_layernorm: @@ -624,10 +696,12 @@ def build_layer(layer_number): if self.post_layer_norm: LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, - eps=config.layernorm_epsilon, - device=device, - dtype=config.torch_dtype) + self.final_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) self.gradient_checkpointing = False @@ -660,14 +734,22 @@ def forward( layer = self._get_layer(index) if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb, - kv_caches[index], use_cache) + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + ) else: - layer_ret = layer(hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache) + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache, + ) hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) @@ -707,7 +789,12 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): past_length = past_key_values[0][0].shape[0] if past_length: full_attention_mask = torch.cat( - (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1) + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) if padding_mask is not None: full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) if not past_length and padding_mask is not None: @@ -718,7 +805,7 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): def get_position_ids(self, input_ids, device): batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)) return position_ids def _set_gradient_checkpointing(self, module, value=False): @@ -734,10 +821,12 @@ def __init__(self, config: ChatGLMConfig, device=None): self.hidden_size = config.hidden_size # Word embeddings (parallel). - self.word_embeddings = nn.Embedding(config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device) + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device, + ) self.fp32_residual_connection = config.fp32_residual_connection def forward(self, input_ids): @@ -773,17 +862,21 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): rotary_dim = (config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels) - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, - original_impl=config.original_rope, - device=device, - dtype=config.torch_dtype) + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + original_impl=config.original_rope, + device=device, + dtype=config.torch_dtype, + ) self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, - config.hidden_size, - config.padded_vocab_size, - bias=False, - dtype=config.torch_dtype, - **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs, + ) self.pre_seq_len = config.pre_seq_len self.prefix_projection = config.prefix_projection if self.pre_seq_len is not None: @@ -797,10 +890,15 @@ def get_input_embeddings(self): return self.embedding.word_embeddings def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)) past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view(batch_size, self.pre_seq_len, self.num_layers * 2, - self.multi_query_group_num, self.kv_channels) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels, + ) # seq_len, b, nh, hidden_size past_key_values = self.dropout(past_key_values) past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) @@ -821,7 +919,7 @@ def forward( output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) batch_size, seq_length = input_ids.shape @@ -830,12 +928,19 @@ def forward( if self.pre_seq_len is not None: if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, - device=input_ids.device, - dtype=inputs_embeds.dtype) + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], - dim=-1) + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): @@ -856,10 +961,16 @@ def forward( rotary_pos_emb=rotary_pos_emb, kv_caches=past_key_values, use_cache=use_cache, - output_hidden_states=output_hidden_states) + output_hidden_states=output_hidden_states, + ) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -870,6 +981,7 @@ def forward( def quantize(self, weight_bit_width: int): from .quantization import quantize + quantize(self.encoder, weight_bit_width) return self @@ -902,7 +1014,9 @@ def _update_model_kwargs_for_generation( if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, + ) # update position ids if "position_ids" in model_kwargs: @@ -914,13 +1028,15 @@ def _update_model_kwargs_for_generation( model_kwargs["is_first_forward"] = False return model_kwargs - def prepare_inputs_for_generation(self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - is_first_forward: bool = True, - **kwargs) -> dict: + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs, + ) -> dict: # only last token for input_ids if past is not None if position_ids is None: position_ids = self.get_position_ids(input_ids, device=input_ids.device) @@ -932,7 +1048,7 @@ def prepare_inputs_for_generation(self, "past_key_values": past_key_values, "position_ids": position_ids, "attention_mask": attention_mask, - "return_last_logit": True + "return_last_logit": True, } def forward( @@ -950,7 +1066,7 @@ def forward( return_last_logit: Optional[bool] = False, ): use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) transformer_outputs = self.transformer( input_ids=input_ids, @@ -1034,17 +1150,19 @@ def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, st return inputs @torch.no_grad() - def chat(self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = None, - max_length: int = 8192, - num_beams=1, - do_sample=True, - top_p=0.8, - temperature=0.8, - logits_processor=None, - **kwargs): + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 8192, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs, + ): if history is None: history = [] if logits_processor is None: @@ -1057,7 +1175,7 @@ def chat(self, "top_p": top_p, "temperature": temperature, "logits_processor": logits_processor, - **kwargs + **kwargs, } inputs = self.build_inputs(tokenizer, query, history=history) outputs = self.generate(**inputs, **gen_kwargs) @@ -1068,18 +1186,20 @@ def chat(self, return response, history @torch.no_grad() - def stream_chat(self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = None, - past_key_values=None, - max_length: int = 8192, - do_sample=True, - top_p=0.8, - temperature=0.8, - logits_processor=None, - return_past_key_values=False, - **kwargs): + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 8192, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs, + ): if history is None: history = [] if logits_processor is None: @@ -1091,7 +1211,7 @@ def stream_chat(self, "top_p": top_p, "temperature": temperature, "logits_processor": logits_processor, - **kwargs + **kwargs, } if past_key_values is None and not return_past_key_values: inputs = self.build_inputs(tokenizer, query, history=history) @@ -1104,11 +1224,13 @@ def stream_chat(self, inputs.position_ids += past_length attention_mask = inputs.attention_mask attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - for outputs in self.stream_generate(**inputs, - past_key_values=past_key_values, - return_past_key_values=return_past_key_values, - **gen_kwargs): + inputs["attention_mask"] = attention_mask + for outputs in self.stream_generate( + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs, + ): if return_past_key_values: outputs, past_key_values = outputs outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] @@ -1138,12 +1260,15 @@ def stream_generate( generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + bos_token_id, eos_token_id = ( + generation_config.bos_token_id, + generation_config.eos_token_id, + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None) if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " @@ -1152,7 +1277,7 @@ def stream_generate( UserWarning, ) elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length) if not has_default_max_length: logger.warn( f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" @@ -1163,14 +1288,14 @@ def stream_generate( ) if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids") logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" " increasing `max_new_tokens`.") # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList()) + stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()) logits_processor = self._get_logits_processor( generation_config=generation_config, @@ -1237,9 +1362,11 @@ def quantize(self, bits: int, empty_init=False, device=None, **kwargs): self.config.quantization_bit = bits - self.transformer.encoder = quantize(self.transformer.encoder, - bits, - empty_init=empty_init, - device=device, - **kwargs) + self.transformer.encoder = quantize( + self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, + **kwargs, + ) return self From 243f7b8d742997ac99453baa65432ffb2c9b25f8 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Wed, 19 Jul 2023 14:19:10 +0800 Subject: [PATCH 15/15] [shardformer] fix chatglm configuration with pre-commit --- .../chatglm2_6b/configuration_chatglm.py | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py index 3730b4e34d86..3e78732be2da 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py @@ -3,34 +3,33 @@ class ChatGLMConfig(PretrainedConfig): model_type = "chatglm" - def __init__( - self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs - ): + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): self.num_layers = num_layers self.vocab_size = padded_vocab_size self.padded_vocab_size = padded_vocab_size @@ -56,4 +55,4 @@ def __init__( self.quantization_bit = quantization_bit self.pre_seq_len = pre_seq_len self.prefix_projection = prefix_projection - super().__init__(**kwargs) \ No newline at end of file + super().__init__(**kwargs)