2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/normalization.py

```diff
@@ -110,7 +110,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
         LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
+        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
             normalized_shape = module.weight.shape[0]
             eps = module.variance_epsilon
             elementwise_affine = True
```
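For context, a minimal sketch (an illustration with a stand-in class, not ColossalAI's actual converter) of the duck-typing that the check above relies on: HF's `LlamaRMSNorm` and `MistralRMSNorm` both expose `weight` and `variance_epsilon`, so the fused layer's parameters can be read straight off the source module.

```python
import torch
import torch.nn as nn


class TinyRMSNorm(nn.Module):
    """Stand-in for transformers' LlamaRMSNorm/MistralRMSNorm (illustrative only)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


def extract_rmsnorm_args(module: nn.Module):
    # Mirrors the class-name dispatch in from_native_module above
    # (with the hypothetical stand-in added to the accepted list).
    if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm", "TinyRMSNorm"]:
        return module.weight.shape[0], module.variance_epsilon, True
    raise TypeError(f"unsupported norm: {module.__class__.__name__}")


normalized_shape, eps, elementwise_affine = extract_rmsnorm_args(TinyRMSNorm(256))
print(normalized_shape, eps, elementwise_affine)  # 256 1e-06 True
```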
18 changes: 7 additions & 11 deletions colossalai/shardformer/modeling/mistral.py

```diff
@@ -1,14 +1,6 @@
 import warnings
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
-from transformers.utils import logging
 
 
 def get_mistral_flash_attention_forward():
@@ -29,8 +21,12 @@ def forward(
         assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        key_states = (
+            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
```
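The reformatted projections above implement grouped-query attention, where `num_key_value_heads` is smaller than `num_heads`. A self-contained sketch of the shape flow, with all dimensions chosen purely for illustration:

```python
import torch
import torch.nn as nn

bsz, q_len = 2, 8        # q_len % 4 == 0, as the flash-attention assert requires
hidden_size = 64
num_heads = 8            # query heads
num_key_value_heads = 2  # KV heads (grouped-query attention)
head_dim = hidden_size // num_heads

q_proj = nn.Linear(hidden_size, num_heads * head_dim)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim)

hidden_states = torch.randn(bsz, q_len, hidden_size)
# Same view/transpose chain as the diff: (bsz, q_len, hidden) -> (bsz, heads, q_len, head_dim)
query_states = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = k_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

print(query_states.shape)  # torch.Size([2, 8, 8, 8]) -> (bsz, num_heads, q_len, head_dim)
print(key_states.shape)    # torch.Size([2, 2, 8, 8]) -> (bsz, num_kv_heads, q_len, head_dim)
```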
8 changes: 5 additions & 3 deletions colossalai/shardformer/policies/bloom.py

```diff
@@ -24,9 +24,11 @@ class BloomPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
         import transformers
-        version_string = transformers.__version__
-        version_tuple = tuple(map(int, version_string.split('.')[:3]))
-        assert version_tuple <= (4, 33, 0), "The Bloom model should run on a transformers version not greater than 4.33.0."
+        from packaging.version import Version
+
+        assert Version(transformers.__version__) <= Version(
+            "4.33.0"
+        ), "The Bloom model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
```
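The same version-guard rewrite recurs in the Falcon, OPT, and Whisper policies below, and it is more than a style change: hand-rolled `int()` parsing raises on pre-release tags, which `packaging.version` (already a transformers dependency) parses and orders correctly. A small sketch:

```python
from packaging.version import Version

# Naive integer parsing chokes on pre-release tags such as "4.33.0rc1":
for version_string in ["4.33.0", "4.33.0rc1"]:
    try:
        version_tuple = tuple(map(int, version_string.split(".")[:3]))
        print(version_string, "->", version_tuple)
    except ValueError as exc:
        print(version_string, "-> ValueError:", exc)  # int("0rc1") fails

# packaging.version parses both and orders pre-releases before releases:
assert Version("4.33.0rc1") <= Version("4.33.0")
print(Version("4.33.0rc1") <= Version("4.33.0"))  # True
```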
8 changes: 5 additions & 3 deletions colossalai/shardformer/policies/falcon.py

```diff
@@ -22,9 +22,11 @@ class FalconPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
         import transformers
-        version_string = transformers.__version__
-        version_tuple = tuple(map(int, version_string.split('.')[:3]))
-        assert version_tuple <= (4, 33, 0), "The Falcon model should run on a transformers version not greater than 4.33.0."
+        from packaging.version import Version
+
+        assert Version(transformers.__version__) <= Version(
+            "4.33.0"
+        ), "The Falcon model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
```
15 changes: 9 additions & 6 deletions colossalai/shardformer/policies/mistral.py

```diff
@@ -1,10 +1,7 @@
 import warnings
 from functools import partial
-from typing import Callable, Dict, List, Union
+from typing import Dict, Union
 
 import torch.nn as nn
-from torch import Tensor
-from torch.nn import Module
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
@@ -37,13 +34,16 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
 
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
-            warnings.warn("Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+            warnings.warn(
+                "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
+            )
 
         if self.shard_config.enable_tensor_parallelism:
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+                // self.shard_config.tensor_parallel_size,
             }
 
             policy[MistralDecoderLayer] = ModulePolicyDescription(
@@ -129,6 +129,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
     def postprocess(self):
         return self.model
 
+
 class MistralModelPolicy(MistralPolicy):
     def __init__(self) -> None:
         super().__init__()
@@ -139,6 +140,7 @@ def module_policy(self):
 
         return super().module_policy()
 
+
 class MistralForCausalLMPolicy(MistralPolicy):
     def module_policy(self):
         from transformers import MistralForCausalLM
@@ -164,6 +166,7 @@ def module_policy(self):
 
         return policy
 
+
 class MistralForSequenceClassificationPolicy(MistralPolicy):
     def module_policy(self):
         from transformers import MistralForSequenceClassification
```
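The `decoder_attribute_replacement` entries reformatted above divide the per-layer attention bookkeeping by the tensor-parallel degree. A standalone sketch of the arithmetic (the concrete sizes are illustrative assumptions, not taken from the PR):

```python
# When a decoder layer's projections are sharded across tensor-parallel ranks,
# each rank sees a proportional slice of heads and hidden size.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8  # Mistral-7B-style grouped-query attention
tensor_parallel_size = 4

replacement = {
    "self_attn.hidden_size": hidden_size // tensor_parallel_size,                   # 1024
    "self_attn.num_heads": num_attention_heads // tensor_parallel_size,             # 8
    "self_attn.num_key_value_heads": num_key_value_heads // tensor_parallel_size,   # 2
}
print(replacement)
# Note: num_key_value_heads must be divisible by tensor_parallel_size,
# otherwise the KV heads cannot be sharded evenly across ranks.
```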
8 changes: 5 additions & 3 deletions colossalai/shardformer/policies/opt.py

```diff
@@ -25,9 +25,11 @@ class OPTPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
         import transformers
-        version_string = transformers.__version__
-        version_tuple = tuple(map(int, version_string.split('.')[:3]))
-        assert version_tuple <= (4, 33, 0), "The OPT model should run on a transformers version not greater than 4.33.0."
+        from packaging.version import Version
+
+        assert Version(transformers.__version__) <= Version(
+            "4.33.0"
+        ), "The OPT model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
```
8 changes: 5 additions & 3 deletions colossalai/shardformer/policies/whisper.py

```diff
@@ -29,9 +29,11 @@ class WhisperPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
         import transformers
-        version_string = transformers.__version__
-        version_tuple = tuple(map(int, version_string.split('.')[:3]))
-        assert version_tuple <= (4, 33, 0), "The Whisper model should run on a transformers version not greater than 4.33.0."
+        from packaging.version import Version
+
+        assert Version(transformers.__version__) <= Version(
+            "4.33.0"
+        ), "The Whisper model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
```
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py

```diff
@@ -12,6 +12,7 @@
 from .t5 import *
 from .vit import *
 from .whisper import *
+
 try:
     from .mistral import *
 except ImportError:
```
15 changes: 7 additions & 8 deletions tests/kit/model_zoo/transformers/mistral.py

```diff
@@ -1,14 +1,14 @@
 import torch
 import transformers
+from transformers import MistralConfig
 
 from ..registry import ModelAttribute, model_zoo
 
-from transformers import MistralConfig
-
 # ===============================
 # Register single-sentence Mistral
 # ===============================
 
+
 def data_gen():
     # Generated from following code snippet
     #
@@ -18,23 +18,26 @@ def data_gen():
     #     tokenized_input = tokenizer([input], return_tensors="pt")
     #     input_ids = tokenized_input['input_ids']
     #     attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
+    input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
     attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     return dict(input_ids=input_ids, attention_mask=attention_mask)
 
+
 def data_gen_for_lm():
     # LM data gen
     # the `labels` of the LM are the output tokens; since there is no padding, reuse `input_ids` as `labels`
     data = data_gen()
     data["labels"] = data["input_ids"].clone()
     return data
 
+
 def data_gen_for_sequence_classification():
     # sequence classification data gen
     data = data_gen()
     data["labels"] = torch.tensor([1], dtype=torch.int64)
     return data
 
+
 # define output transform function
 output_transform_fn = lambda x: x
@@ -46,11 +49,7 @@ def data_gen_for_sequence_classification():
 loss_fn_for_seq_classification = lambda output: output.logits.mean()
 
 config = MistralConfig(
-    hidden_size=256,
-    intermediate_size=256,
-    num_attention_heads=64,
-    num_hidden_layers=2,
-    vocab_size=50258
+    hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
 )
 
 model_zoo.register(
```
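A brief note on `data_gen_for_lm` above: because the sample has no padding, reusing `input_ids` as `labels` is the standard HF causal-LM convention; the model shifts the labels internally so that each position predicts the next token. A sketch of that convention (illustrative, outside any model):

```python
import torch

input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
labels = input_ids.clone()

# Inside a causal-LM head, loss compares logits[:, :-1] against labels[:, 1:],
# i.e. each token predicts its successor. Positions labeled -100 would be
# ignored by the loss, but none are needed here since there is no padding.
shift_labels = labels[:, 1:]
print(shift_labels)  # tensor([[ 1984, 16020,  2076,  2487,   349, 21375,  4749]])
```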
38 changes: 29 additions & 9 deletions tests/test_shardformer/test_model/test_shard_mistral.py

```diff
@@ -50,10 +50,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     else:
         atol, rtol = 5e-3, 5e-3
     row_layer_grads = get_grad_tensors_for_check(
-        mistral_model, shard_mistral_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+        mistral_model,
+        shard_mistral_model,
+        row_layer_for_check,
+        tp_group,
+        atol=atol,
+        rtol=rtol,
+        dim=0,
+        verbose=False,
     )
     col_layer_grads = get_grad_tensors_for_check(
-        mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+        mistral_model,
+        shard_mistral_model,
+        col_layer_for_check,
+        tp_group,
+        atol=atol,
+        rtol=rtol,
+        dim=1,
+        verbose=False,
     )
     grads_to_check.update(col_layer_grads)
     grads_to_check.update(row_layer_grads)
@@ -81,7 +95,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     else:
         atol, rtol = 5e-3, 5e-3
     check_weight(
-        mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+        mistral_model,
+        shard_mistral_model,
+        col_layer_for_check,
+        tp_group,
+        atol=atol,
+        rtol=rtol,
+        dim=1,
+        verbose=False,
     )
 
     # check grads
@@ -101,10 +122,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         "precision": "fp32",
     },
     {
-        "tp_size": 2,
-        "pp_size": 1,
-        "enable_all_optimization": True,
-        "use_lazy_init": False,
+        "tp_size": 2,
+        "pp_size": 1,
+        "enable_all_optimization": True,
+        "use_lazy_init": False,
         "precision": "fp32",
     },
     {
@@ -135,7 +156,6 @@ def check_mistral(rank, world_size, port):
     run_mistral_test()
 
 
-
 @pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
@@ -145,4 +165,4 @@ def test_mistral():
 
 
 if __name__ == "__main__":
-    test_mistral()
\ No newline at end of file
+    test_mistral()
```
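Finally, a sketch of the invariant that `check_weight` with `dim=1` exercises (stated under assumptions; this is not the test-kit API): a column-parallel linear layer splits its weight along one dimension, so the per-rank shards, concatenated back along that dimension, must reproduce the unsharded weight within tolerance.

```python
import torch

tp_size = 2
full_weight = torch.randn(8, 4)  # stand-in for an unsharded o_proj-style weight

# dim=1 matches the `dim=1` argument passed to check_weight above:
# chunk simulates what each tensor-parallel rank would hold.
shards = torch.chunk(full_weight, tp_size, dim=1)

reassembled = torch.cat(shards, dim=1)
torch.testing.assert_close(reassembled, full_weight, atol=5e-3, rtol=5e-3)
print("shards reassemble to the full weight")
```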