Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions colossalai/shardformer/layer/embedding1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,22 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()

self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)

self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
# self.gather_output = gather_output
self.gather_output = gather_output

if device is None:
device = get_current_device()
Expand All @@ -95,7 +96,9 @@ def __init__(self,

@staticmethod
def from_native_module(module: nn.Embedding,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
*args,
**kwargs) -> "Embedding1D":
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
Expand Down Expand Up @@ -123,7 +126,9 @@ def from_native_module(module: nn.Embedding,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
sparse=sparse,
*args,
**kwargs)

# copy the weight
with torch.no_grad():
Expand All @@ -133,7 +138,7 @@ def from_native_module(module: nn.Embedding,
return embedding

def reset_parameters(self, weight_initializer) -> None:
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()

Expand All @@ -144,6 +149,9 @@ def _fill_padding_idx_with_zero(self) -> None:

def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)

return output
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
else:
return output_parallel
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

Expand Down
3 changes: 1 addition & 2 deletions colossalai/shardformer/policies/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
T5Stack,
)

from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

Expand Down
11 changes: 9 additions & 2 deletions colossalai/shardformer/shard/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,14 @@ def _replace_sub_module(
if description.ignore_if_not_exist and native_sub_module is None:
continue

replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
try:
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
f" with {target_module.__qualname__} with the exception: {e}. "
"Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
)

setattr_(org_layer, suffix, replace_layer)
2 changes: 1 addition & 1 deletion colossalai/testing/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
assert torch.allclose(
out1, out2, atol=atol, rtol=rtol
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
else:
assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
30 changes: 19 additions & 11 deletions tests/kit/model_zoo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,35 @@ def register(self,
model_fn: Callable,
data_gen_fn: Callable,
output_transform_fn: Callable,
loss_fn: Callable = None,
model_attribute: ModelAttribute = None):
"""
Register a model and data generation function.

Examples:
>>> # Register
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
>>> data = resnresnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)

```python
# normal forward workflow
model = resnet18()
data = resnet18_data_gen()
output = model(**data)
transformed_output = output_transform_fn(output)
loss = loss_fn(transformed_output)

# Register
model_zoo = ModelZooRegistry()
model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
```

Args:
name (str): Name of the model.
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
output_transform_fn (callable): A function that transforms the output of the model into Dict.
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
output_transform_fn (Callable): A function that transforms the output of the model into Dict.
loss_fn (Callable): a function to compute the loss from the given output. Defaults to None
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

def get_sub_registry(self, keyword: str):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .gpt import *
from .llama import *
from .opt import *
from .t5 import *
76 changes: 76 additions & 0 deletions tests/kit/model_zoo/transformers/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
HAS_LLAMA = True
except ImportError:
HAS_LLAMA = False

if HAS_LLAMA:
# ===============================
# Register LLaMA
# ===============================

def data_gen():
    """Return a fixed, pre-tokenized input sample for the LLaMA test models.

    The input ids correspond to the sentence 'Hello, my dog is cute'.
    They were produced with the following snippet:

        from transformers import LlamaTokenizerFast
        tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
        input = 'Hello, my dog is cute'
        tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')

    Returns:
        dict: ``input_ids`` and ``attention_mask``, both int64 tensors of
        shape (1, 8).
    """
    # Build the integer tensors directly. ``torch.Tensor(...)`` would create a
    # float32 tensor first, which cannot represent every token id exactly once
    # ids exceed 2**24, so the later ``.long()`` cast could silently corrupt them.
    input_ids = torch.tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)

# label is needed for casual lm
def data_gen_for_casual_lm():
    """Return the base LLaMA sample extended with ``labels``.

    The labels are an independent copy of ``input_ids``, so the returned dict
    can be fed to a language-modeling head that computes its own loss.
    """
    sample = data_gen()
    sample['labels'] = sample['input_ids'].clone()
    return sample

# transform the output to a dict
output_transform_fn = lambda x: x

# function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean()
loss_fn_for_casual_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

config = LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)

# register the following models
# transformers.LlamaModel,
# transformers.LlamaForCausalLM,
# transformers.LlamaForSequenceClassification,
model_zoo.register(name='transformers_llama',
model_fn=lambda: transformers.LlamaModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_casual_lm',
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_sequence_classification',
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True))
53 changes: 41 additions & 12 deletions tests/kit/model_zoo/transformers/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,70 @@
# ===============================
# Register single-sentence T5
# ===============================
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)


# define data gen function
def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
# Generated from following code snippet
#
# from transformers import T5Config, T5Tokenizer
# config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
return dict(input_ids=input_ids)


def data_gen_for_conditional_generation():
    """Return the encoder-only sample extended with target ``labels``.

    The labels were generated with the following code:

        labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

    Returns:
        dict: the encoder inputs plus an int64 ``labels`` tensor of shape (1, 6).
    """
    data = data_gen_for_encoder_only()
    # Create the ids as int64 directly instead of going through a float32
    # ``torch.Tensor`` and casting, which is lossy for large token ids.
    data['labels'] = torch.tensor([[644, 4598, 229, 19250, 5, 1]], dtype=torch.int64)
    return data


def data_gen_for_t5_model():
    """Return the encoder-only sample extended with ``decoder_input_ids``.

    The decoder input ids were obtained with the following code:

        decoder_input_ids = model._shift_right(input_ids)

    Returns:
        dict: the encoder inputs plus an int64 ``decoder_input_ids`` tensor of
        shape (1, 11).
    """
    data = data_gen_for_encoder_only()
    # Create the ids as int64 directly instead of going through a float32
    # ``torch.Tensor`` and casting, which is lossy for large token ids.
    data['decoder_input_ids'] = torch.tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]],
                                             dtype=torch.int64)
    return data


# output transform function
output_transform_fn = lambda x: x

config = transformers.T5Config(d_model=128, num_layers=2)
# define loss function
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss

# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

# register the following models
# transformers.T5Model,
# transformers.T5ForConditionalGeneration,
# transformers.T5EncoderModel,
model_zoo.register(name='transformers_t5',
model_fn=lambda: transformers.T5Model(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_t5_model,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_t5_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_for_conditional_generation',
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_conditional_generation,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_encoder_model',
model_fn=lambda: transformers.T5EncoderModel(config),
data_gen_fn=data_gen_for_encoder_only,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True))
2 changes: 1 addition & 1 deletion tests/test_booster/test_mixed_precision/test_fp16_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items():
# dlrm_interactionarch has no parameters, so skip
if name == 'dlrm_interactionarch':
continue
Expand Down
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair

for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
skipped_models = []

for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# FIXME(ver217): fix these models
if name in ignore_models:
skipped_models.append(name)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_ddp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


def check_torch_ddp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if name == 'dlrm_interactionarch':
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if any(element in name for element in [
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
'torchvision_inception_v3'
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_diffusers():

sub_model_zoo = model_zoo.get_sub_registry('diffusers')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
trace_and_compare(model_fn, data, output_transform_fn)
torch.cuda.synchronize()
Expand All @@ -60,7 +60,7 @@ def test_torch_diffusers():

sub_model_zoo = model_zoo.get_sub_registry('diffusers')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
model = model_fn()
output = model(**data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_timm_models():

sub_model_zoo = model_zoo.get_sub_registry('timm')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}
Expand Down
Loading