Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ We will follow this roadmap to develop Shardformer:
- [ ] SwinTransformer
- [ ] SwinTransformer V2
- [ ] Audio
- [ ] Whisper
- [x] Whisper
- [ ] Multi-modal
- [ ] To be added

Expand Down
13 changes: 12 additions & 1 deletion colossalai/shardformer/layer/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def __init__(self,
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
Expand All @@ -206,6 +205,9 @@ def __init__(self,
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

# padding index
self.padding_idx = self._select_padding_idx(padding_idx)

# parameter
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
Expand Down Expand Up @@ -263,6 +265,15 @@ def _fill_padding_idx_with_zero(self) -> None:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

def _select_padding_idx(self, padding_idx: int):
# select padding index according to the rank
if padding_idx is None:
return None
elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index:
return padding_idx - self.vocab_start_index
else:
return None

def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
Expand Down
8 changes: 8 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ class PolicyLocation:
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),

# Whisper
"transformers.models.whisper.modeling_whisper.WhisperModel":
PolicyLocation(file_name="whisper", class_name="WhisperModelPolicy"),
"transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration":
PolicyLocation(file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"),
"transformers.models.whisper.modeling_whisper.WhisperForAudioClassification":
PolicyLocation(file_name="whisper", class_name="WhisperForAudioClassificationPolicy"),

# Sam
"transformers.models.sam.modeling_sam.SamModel":
PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
Expand Down
232 changes: 232 additions & 0 deletions colossalai/shardformer/policies/whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import torch.nn as nn

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

# NOTE: the last entry previously read 'WhisperForAudioClassification', which
# does not exist in this module — the class defined below is
# 'WhisperForAudioClassificationPolicy'. A star-import against the old
# __all__ would raise AttributeError.
__all__ = [
    'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
    'WhisperForAudioClassificationPolicy'
]


class WhisperPolicy(Policy):
    """Base shardformer policy for HuggingFace Whisper models.

    Describes how to shard Whisper for tensor parallelism: column/row-parallel
    splits of the attention and feed-forward projections in every encoder and
    decoder layer, a vocab-parallel decoder token embedding, and (optionally)
    fused LayerNorm replacements.
    """

    def config_sanity_check(self):
        pass

    def preprocess(self):
        r"""
        Reshape the Embedding layer to make the embedding dimension divisible by world_size.
        """
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            # pad the vocabulary so each tensor-parallel rank holds an equal shard
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self):
        """Return the sharding description, keyed by Whisper module classes."""
        from transformers.models.whisper.modeling_whisper import (
            WhisperDecoder,
            WhisperDecoderLayer,
            WhisperEncoder,
            WhisperEncoderLayer,
        )

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[WhisperEncoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    # per-rank attention width after the column-parallel split
                    "self_attn.embed_dim":
                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
                    "self_attn.num_heads":
                        self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.out_proj",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc1",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc2",
                        target_module=col_nn.Linear1D_Row,
                    ),
                ])

            policy[WhisperDecoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.embed_dim":
                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
                    "self_attn.num_heads":
                        self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size,
                    "encoder_attn.embed_dim":
                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
                    # FIX: HF WhisperDecoderLayer builds its cross-attention
                    # (encoder_attn) with config.decoder_attention_heads, not
                    # encoder_attention_heads; using the encoder count breaks
                    # sharding when the two differ.
                    "encoder_attn.num_heads":
                        self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.out_proj",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="encoder_attn.q_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="encoder_attn.k_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="encoder_attn.v_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="encoder_attn.out_proj",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc1",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc2",
                        target_module=col_nn.Linear1D_Row,
                    ),
                ])

            # decoder token embedding is sharded along the vocab dimension
            policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=col_nn.VocabParallelEmbedding1D,
                ),
            ])

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle encoder layer norms
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="final_layer_norm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ],
                                                        policy=policy,
                                                        target_key=WhisperEncoderLayer)

            # handle decoder layer norms
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="final_layer_norm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ],
                                                        policy=policy,
                                                        target_key=WhisperDecoderLayer)

            # final encoder layer norm
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ],
                                                        policy=policy,
                                                        target_key=WhisperEncoder)

            # final decoder layer norm
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ],
                                                        policy=policy,
                                                        target_key=WhisperDecoder)
        return policy

    def add_lm_head_policy(self, base_policy):
        """Extend ``base_policy`` with a column-parallel LM head (proj_out).

        gather_output=True keeps the logits full-size for the loss.
        """
        from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration

        # optimize for tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
                                                        policy=base_policy,
                                                        target_key=WhisperForConditionalGeneration)

        return base_policy

    def postprocess(self):
        return self.model


# WhisperModel
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
    """Policy for the bare ``WhisperModel``; inherits all sharding
    behavior from :class:`WhisperPolicy` unchanged."""

    def __init__(self) -> None:
        super().__init__()


# WhisperForConditionalGeneration
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
    """Policy for ``WhisperForConditionalGeneration``: base sharding plus a
    tensor-parallel LM head, with the head weight re-tied to the decoder
    token embedding after sharding."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        # base Whisper sharding, then the proj_out (LM head) replacement
        base = super().module_policy()
        return self.add_lm_head_policy(base)

    def postprocess(self):
        # re-tie proj_out to the (possibly sharded) decoder token embedding
        shared = getattr_(self.model, "model.decoder.embed_tokens.weight")
        setattr_(self.model, "proj_out.weight", shared)
        return self.model


# WhisperForAudioClassification
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
    """Policy for ``WhisperForAudioClassification``; inherits all sharding
    behavior from :class:`WhisperPolicy` unchanged (encoder-only usage,
    no LM head to replace)."""

    def __init__(self) -> None:
        super().__init__()
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .sam import *
from .t5 import *
from .vit import *
from .whisper import *
91 changes: 91 additions & 0 deletions tests/kit/model_zoo/transformers/whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence Whisper
# ===============================


# define data gen function
def data_gen():
    """Return a minimal Whisper forward input.

    Mirrors what the following snippet produces for ``openai/whisper-base``:

        from transformers import AutoFeatureExtractor, WhisperModel
        from datasets import load_dataset

        model = WhisperModel.from_pretrained("openai/whisper-base")
        feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
        input_features = inputs.input_features
        decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
    """
    # (batch=1, mel_bins=80, frames=3000) random log-mel features stand in
    # for a real extractor output; 50258 is whisper's decoder_start_token_id
    features = torch.randn(1, 80, 3000)
    start_ids = torch.full((1, 2), 50258, dtype=torch.long)
    return dict(input_features=features, decoder_input_ids=start_ids)


def data_gen_for_conditional_generation():
    """Return Whisper inputs plus ``labels`` for the LM loss.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Indices should either be in `[0, ..., config.vocab_size]` or -100
        (see `input_ids` docstring). Tokens set to `-100` are ignored
        (masked); the loss is only computed for tokens with labels in
        `[0, ..., config.vocab_size]`.
    """
    inputs = data_gen()
    inputs['labels'] = torch.tensor([[0, 1]], dtype=torch.int64)
    return inputs


def data_gen_for_audio_classification():
    """Return Whisper inputs plus a classification ``labels`` tensor.

    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Indices should be in `[0, ..., config.num_labels - 1]`. If
        `config.num_labels == 1` a regression (MSE) loss is computed,
        otherwise cross-entropy.

    ``WhisperForAudioClassification`` is encoder-only, so the
    ``decoder_input_ids`` produced by :func:`data_gen` are dropped.
    """
    inputs = data_gen()
    del inputs['decoder_input_ids']
    inputs['labels'] = torch.tensor([1], dtype=torch.int64)
    return inputs


# define output transform function (identity: outputs are used as-is)
output_transform_fn = lambda x: x

# define loss function
# loss_fn: for models without a built-in loss (WhisperModel) — reduce the
#   hidden states to a scalar so backward() can run in tests
# loss_fn_attr: for models that compute their own loss from `labels`
loss_fn = lambda x: x.last_hidden_state.mean()
loss_fn_attr = lambda x: x.loss

# a deliberately small Whisper config (2 layers, d_model=256) so the
# registered models are cheap to build and run in unit tests
config = transformers.WhisperConfig(
    classifier_proj_size=256,
    d_model=256,
    decoder_attention_heads=4,
    decoder_ffn_dim=1536,
    decoder_layers=2,
    encoder_attention_heads=4,
    encoder_ffn_dim=1536,
    encoder_layers=2,
    vocab_size=51866,
)

# register the Whisper variants
model_zoo.register(name='transformers_whisper',
                   model_fn=lambda: transformers.WhisperModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_whisperForConditionalGeneration',
                   model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
                   data_gen_fn=data_gen_for_conditional_generation,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_attr,
                   model_attribute=ModelAttribute(has_control_flow=True))

# NOTE(review): the name below doubles "whisper"/"Whisper"
# ('transformers_whisperWhisperForAudioClassification') — looks like a typo
# for 'transformers_whisperForAudioClassification'. Left as-is because other
# tests may already filter the zoo by this exact string; confirm before renaming.
model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
                   model_fn=lambda: transformers.WhisperForAudioClassification(config),
                   data_gen_fn=data_gen_for_audio_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_attr,
                   model_attribute=ModelAttribute(has_control_flow=True))
Loading