3 changes: 2 additions & 1 deletion colossalai/shardformer/README.md
@@ -104,7 +104,8 @@ We will follow this roadmap to develop Shardformer:
- [ ] Audio
- [ ] Whisper
- [ ] Multi-modal
- [ ] To be added
- [x] SAM
- [x] BLIP-2

## 💡 API Design

60 changes: 60 additions & 0 deletions colossalai/shardformer/modeling/blip2.py
@@ -0,0 +1,60 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def forward_fn():

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # modified from original code, which is:
        # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
        #     2, 0, 3, 1, 4
        # )
        # to:
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs

    return forward
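
The only functional change from the upstream HuggingFace attention is the `reshape` that infers the head dimension with `-1` instead of hard-coding `embed_dim // self.num_heads`. The likely reason: once Shardformer splits the `qkv` projection (and `self.num_heads`) across tensor-parallel ranks, `embed_dim` read from the unsharded `hidden_states` no longer divides by the local head count to give the true head size. A minimal sketch of the shape arithmetic, with purely illustrative sizes and a hypothetical `tp_size` (nothing here comes from this PR):

```python
import torch

# illustrative sizes only
bsz, tgt_len, embed_dim, num_heads, tp_size = 2, 5, 32, 4, 2
head_dim = embed_dim // num_heads    # 8

# unsharded: the hard-coded head size and -1 agree
full_qkv = torch.randn(bsz, tgt_len, 3 * embed_dim)
full = full_qkv.reshape(bsz, tgt_len, 3, num_heads, -1)
assert full.shape[-1] == head_dim

# tensor parallel: each rank holds embed_dim // tp_size features and
# num_heads // tp_size heads; embed_dim // local_heads would give 16, not 8,
# while -1 still infers the correct per-head size
local_heads = num_heads // tp_size
local_qkv = torch.randn(bsz, tgt_len, 3 * embed_dim // tp_size)
local = local_qkv.reshape(bsz, tgt_len, 3, local_heads, -1)
assert local.shape[-1] == head_dim
assert embed_dim // local_heads != head_dim
```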
2 changes: 0 additions & 2 deletions colossalai/shardformer/modeling/sam.py
@@ -1,6 +1,4 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def forward_fn():
6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -104,6 +104,12 @@ class PolicyLocation:
    # Sam
    "transformers.models.sam.modeling_sam.SamModel":
        PolicyLocation(file_name="sam", class_name="SamModelPolicy"),

    # Blip2
    "transformers.models.blip_2.modeling_blip_2.Blip2Model":
        PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
    "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
        PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
}
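
The new entries follow the existing pattern: the key is the fully qualified class name of the HuggingFace model, and `PolicyLocation` points at the policy module (`file_name`) and policy class (`class_name`) under `colossalai/shardformer/policies/`. A rough sketch of how such a table could be resolved at runtime; the `locate_policy` helper and its behaviour are illustrative assumptions, not the actual Shardformer API:

```python
import importlib

def locate_policy(model, policy_table):
    """Find and instantiate the policy registered for a model's fully qualified class name."""
    key = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = policy_table[key]    # e.g. PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy")
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()
```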


304 changes: 304 additions & 0 deletions colossalai/shardformer/policies/blip2.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
@@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
from .gpt import *
from .llama import *
61 changes: 61 additions & 0 deletions tests/kit/model_zoo/transformers/blip2.py
@@ -0,0 +1,61 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-image BLIP-2
# ===============================


# define data gen function
def data_gen():
    # Generated from the following code snippet:
    #
    # from PIL import Image
    # import requests
    # from transformers import Blip2Processor, Blip2Model
    # import torch
    #
    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # image = Image.open(requests.get(url, stream=True).raw)
    #
    # prompt = "Question: how many cats are there? Answer:"
    # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

    pixel_values = torch.rand(1, 3, 224, 224, dtype=torch.float32)
    input_ids = torch.tensor([[2, 45641, 35, 141, 171, 10017, 32, 89, 116, 31652, 35]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    labels = torch.tensor([[34, 56]], dtype=torch.int64)
    return dict(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_blip2_model = lambda x: x.loss

config = transformers.Blip2Config()
config.text_config.num_hidden_layers = 1
config.qformer_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 1
config.qformer_config.attention_probs_dropout_prob = 0
config.qformer_config.hidden_dropout_prob = 0
config.text_config.dropout = 0

# register the blip2 variants
model_zoo.register(name='transformers_blip2',
                   model_fn=lambda: transformers.Blip2Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_blip2_conditional_generation',
                   model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
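
Both entries can then be pulled back out of the registry by keyword, which is exactly what the shardformer test further down does via `model_zoo.get_sub_registry('transformers_blip2')`. A minimal sketch of exercising one registered entry (a hypothetical smoke test, assuming each entry unpacks to the `(model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)` tuple used by the test kit):

```python
# hypothetical usage sketch, not part of this PR
sub_registry = model_zoo.get_sub_registry('transformers_blip2')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_registry.items():
    model = model_fn()
    outputs = model(**data_gen_fn())
    loss = loss_fn(output_transform_fn(outputs))    # lambda x: x.loss on the model output
    loss.backward()
```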
3 changes: 2 additions & 1 deletion tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -88,7 +88,8 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert',
'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining',
'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base',
'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model'
'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model',
'transformers_blip2', 'transformers_vit', 'transformers_vit_for_masked_image_modeling'
]:
continue

3 changes: 2 additions & 1 deletion tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -17,7 +17,8 @@
# These models will get stuck
_STUCK_MODELS = [
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
'transformers_bert_for_pretraining', 'transformers_gpt_double_heads', 'transformers_sam',
'transformers_bert_lm_head_model', 'transformers_bert_for_masked_lm', 'transformers_vit'
]


3 changes: 2 additions & 1 deletion tests/test_lazy/test_distribute.py
@@ -71,7 +71,8 @@ def run_dist_lazy_init(subset, seed: int = 42):

for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith(
('transformers_llama', 'transformers_blip2')):
continue
print_rank_0(name)
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
107 changes: 107 additions & 0 deletions tests/test_shardformer/test_model/test_shard_blip2.py
@@ -0,0 +1,107 @@
import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])

    # do backward
    org_loss.backward()
    shard_loss.backward()

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"

    # check grad

    blip2 = org_model
    sharded_blip2 = sharded_model

    # compare vision_model grad

    org_grad = blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
    shard_grad = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
    shard_weight = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        # gather the gradient shards from both ranks before comparing with the original gradient
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # compare qformer grad
    org_grad = blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
    shard_grad = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
    shard_weight = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # compare language_model grad
    org_grad = blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
    shard_grad = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
    shard_weight = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()


def check_blip2(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_blip2_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_blip2():
    spawn(check_blip2, 2)


if __name__ == "__main__":
    test_blip2()
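
The three gradient comparisons in `check_forward_backward` repeat the same gather-and-compare steps for the vision model, the Q-Former, and the language model. A hedged sketch of how that pattern could be factored into a helper (the `check_grad_equal` name and explicit `world_size` parameter are illustrative, not part of this PR; it assumes the same dim-0 sharding the test already assumes):

```python
def check_grad_equal(org_weight, shard_weight, world_size=2, atol=1e-5):
    """Compare an original weight's gradient with its (possibly sharded) counterpart."""
    org_grad, shard_grad = org_weight.grad, shard_weight.grad
    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        # gather the per-rank gradient shards along dim 0 before comparing
        shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        shard_grad = torch.cat(shard_grad_list, dim=0)
    assert torch.allclose(org_grad, shard_grad, atol=atol), \
        f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"

# e.g. check_grad_equal(blip2.vision_model.encoder.layers[0].self_attn.qkv.weight,
#                       sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight)
```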