9 changes: 9 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
@@ -174,6 +174,15 @@ class PolicyLocation:
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
file_name="mistral", class_name="MistralModelPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation(
file_name="mistral", class_name="MistralForCausalLMPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
}

_INFER_POLICY_LIST = {
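For orientation, here is a minimal sketch of how a `PolicyLocation` entry is typically resolved into a policy instance; the resolver below is illustrative, not the exact `get_autopolicy` implementation:

```python
import importlib


def resolve_policy(policy_location):
    # Illustrative resolver: map file_name (e.g. "mistral") to the module
    # colossalai.shardformer.policies.mistral, then instantiate the class
    # named by class_name (e.g. "MistralForCausalLMPolicy").
    module = importlib.import_module(f"colossalai.shardformer.policies.{policy_location.file_name}")
    return getattr(module, policy_location.class_name)()
```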
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/bloom.py
@@ -21,6 +21,13 @@


class BloomPolicy(Policy):
    def __init__(self) -> None:
        super().__init__()
        import transformers
        version_string = transformers.__version__
        version_tuple = tuple(map(int, version_string.split('.')[:3]))
        assert version_tuple <= (4, 33, 0), "The Bloom model should run on a transformers version not greater than 4.33.0."

    def config_sanity_check(self):
        pass

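One caveat with the guard above (repeated for Falcon, OPT, and Whisper below): `int(...)` over the split version fields raises on pre-release tags such as `4.34.0rc1`. A more tolerant sketch, assuming the commonly available `packaging` library:

```python
import transformers
from packaging import version

# version.parse handles pre-release and dev suffixes ("4.34.0rc1",
# "4.33.0.dev0") that break naive int-parsing of the version string.
assert version.parse(transformers.__version__) <= version.parse("4.33.0"), (
    "The Bloom model should run on a transformers version not greater than 4.33.0."
)
```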
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/falcon.py
@@ -19,6 +19,13 @@


class FalconPolicy(Policy):
    def __init__(self) -> None:
        super().__init__()
        import transformers
        version_string = transformers.__version__
        version_tuple = tuple(map(int, version_string.split('.')[:3]))
        assert version_tuple <= (4, 33, 0), "The Falcon model should run on a transformers version not greater than 4.33.0."

    def config_sanity_check(self):
        pass

14 changes: 14 additions & 0 deletions colossalai/shardformer/policies/mistral.py
@@ -133,6 +133,12 @@ class MistralModelPolicy(MistralPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        if self.pipeline_stage_manager:
            warnings.warn("Mistral doesn't support pipeline parallelism now.")

        return super().module_policy()

class MistralForCausalLMPolicy(MistralPolicy):
    def module_policy(self):
        from transformers import MistralForCausalLM
@@ -150,6 +156,10 @@ def module_policy(self):
                ]
            )
        }

        if self.pipeline_stage_manager:
            warnings.warn("Mistral doesn't support pipeline parallelism now.")

        policy.update(new_item)

        return policy
@@ -171,5 +181,9 @@ def module_policy(self):
                ]
            )
        }

        if self.pipeline_stage_manager:
            warnings.warn("Mistral doesn't support pipeline parallelism now.")

        policy.update(new_item)
        return policy
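A minimal usage sketch for these Mistral policies, assuming the standard `ShardFormer` entry point and an already-initialized distributed context (e.g. via `colossalai.launch`); tensor parallelism only, since the policies warn that pipeline parallelism is unsupported:

```python
import transformers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.mistral import MistralForCausalLMPolicy

# Assumes colossalai.launch(...) has already set up the process group.
config = transformers.MistralConfig(num_hidden_layers=2)
model = transformers.MistralForCausalLM(config)

shard_config = ShardConfig(enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model, policy=MistralForCausalLMPolicy())
```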
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/opt.py
@@ -22,6 +22,13 @@


class OPTPolicy(Policy):
    def __init__(self) -> None:
        super().__init__()
        import transformers
        version_string = transformers.__version__
        version_tuple = tuple(map(int, version_string.split('.')[:3]))
        assert version_tuple <= (4, 33, 0), "The OPT model should run on a transformers version not greater than 4.33.0."

    def config_sanity_check(self):
        pass

7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/whisper.py
@@ -26,6 +26,13 @@


class WhisperPolicy(Policy):
    def __init__(self) -> None:
        super().__init__()
        import transformers
        version_string = transformers.__version__
        version_tuple = tuple(map(int, version_string.split('.')[:3]))
        assert version_tuple <= (4, 33, 0), "The Whisper model should run on a transformers version not greater than 4.33.0."

    def config_sanity_check(self):
        pass

4 changes: 4 additions & 0 deletions tests/kit/model_zoo/transformers/__init__.py
@@ -12,3 +12,7 @@
from .t5 import *
from .vit import *
from .whisper import *
try:
    from .mistral import *
except ImportError:
    print("This version of transformers doesn't support Mistral.")
79 changes: 79 additions & 0 deletions tests/kit/model_zoo/transformers/mistral.py
@@ -0,0 +1,79 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

from transformers import MistralConfig

# ===============================
# Register single-sentence Mistral
# ===============================

def data_gen():
    # Generated from the following code snippet:
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy the length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)

def data_gen_for_lm():
    # For causal LM the `labels` are the target tokens; since there is no padding,
    # `input_ids` can be reused directly as `labels`.
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data

def data_gen_for_sequence_classification():
    # sequence classification data gen
    data = data_gen()
    data["labels"] = torch.tensor([1], dtype=torch.int64)
    return data

# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_mistral_model = lambda x: torch.nn.functional.mse_loss(
    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

config = MistralConfig(
    hidden_size=256,
    intermediate_size=256,
    num_attention_heads=64,
    num_hidden_layers=2,
    vocab_size=50258,
)

model_zoo.register(
name="transformers_mistral",
model_fn=lambda: transformers.MistralModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_mistral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mistral_for_casual_lm",
model_fn=lambda: transformers.MistralForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mistral_for_sequence_classification",
model_fn=lambda: transformers.MistralForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
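These registrations are consumed by the shardformer test below; as a quick sanity check, an entry can also be pulled back out of the registry and run eagerly, for example:

```python
from tests.kit.model_zoo import model_zoo

# Each entry bundles a model factory, a data generator, an output transform,
# and a loss function, matching the register(...) calls above.
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in model_zoo.get_sub_registry(
    "transformers_mistral"
).items():
    model = model_fn()
    output = output_transform_fn(model(**data_gen_fn()))
    print(name, loss_fn(output))
```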
148 changes: 148 additions & 0 deletions tests/test_shardformer/test_model/test_shard_mistral.py
@@ -0,0 +1,148 @@
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # unwrap model
    mistral_model = unwrap_model(org_model, "MistralModel", "model")
    shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model")

    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]

    # Save gradient tensors for comparison between the original model and the sharded model before the optimizer step.
    grads_to_check = {}
    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config["precision"] == "fp32":
            atol, rtol = 5e-5, 1e-4
        else:
            atol, rtol = 5e-3, 5e-3
        row_layer_grads = get_grad_tensors_for_check(
            mistral_model, shard_mistral_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
        )
        col_layer_grads = get_grad_tensors_for_check(
            mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)

    # run the optimizer step
    org_optimizer.step()
    sharded_optimizer.step()

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3

        if org_model.__class__.__name__ == "MistralModel":
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
        check_weight(
            mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()


@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_mistral_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


def check_mistral(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_mistral_test()


@pytest.mark.skip("This test requires a transformers version of at least 4.35.2.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mistral():
    spawn(check_mistral, 4)


if __name__ == "__main__":
    test_mistral()