Merged
6 changes: 2 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -392,7 +392,7 @@ def __init__(
self.pg_mesh,
pipeline_axis=PP_AXIS,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks
num_model_chunks=num_model_chunks,
)

if pp_style == "interleaved":
@@ -405,9 +405,7 @@ def __init__(
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
else:
raise NotImplementedError()
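For orientation (not part of the diff): a minimal sketch of how the two schedules touched above are selected through the plugin. `pp_style`, `num_model_chunks`, `num_microbatches`, and `microbatch_size` are the names used in this hunk; the remaining arguments and the launch call are assumptions about the wider API.

```python
# Illustrative sketch, not part of the diff. tp_size/pp_size and launch_from_torch
# are assumptions about the surrounding API; the pipeline keywords follow the hunk above.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # assumes the process group is set up via torchrun

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    pp_style="interleaved",   # selects InterleavedSchedule; "1f1b" selects OneForwardOneBackwardSchedule
    num_model_chunks=2,       # only meaningful for the interleaved schedule
    num_microbatches=8,       # or pass microbatch_size instead
)
booster = Booster(plugin=plugin)
```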
12 changes: 4 additions & 8 deletions colossalai/pipeline/schedule/interleaved_pp.py
@@ -50,16 +50,13 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
self.batch_size = get_batch_size(batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
if self.num_microbatch is not None:
assert (
self.batch_size % self.num_microbatch == 0
), "Batch size should divided by the number of microbatch"
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
self.microbatch_size = self.batch_size // self.num_microbatch
elif self.microbatch_size is not None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
self.num_microbatch = self.batch_size // self.microbatch_size
else:
raise ValueError(
"Either num_microbatch or microbatch_size should be provided")
raise ValueError("Either num_microbatch or microbatch_size should be provided")

assert (
self.num_microbatch % self.num_model_chunks == 0
@@ -323,10 +320,9 @@ def forward_backward_step(
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)

if num_microbatch_remaining == 0 \
and i + 1 == num_warmup_microbatch:
if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
break

model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)

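For context (not part of the diff): the divisibility checks in `load_batch` above boil down to simple arithmetic. The standalone function below is an illustrative re-statement of those rules, not ColossalAI's API.

```python
# Illustrative re-statement of the load_batch sizing rules shown above.
def resolve_microbatch(batch_size, num_model_chunks, num_microbatch=None, microbatch_size=None):
    if num_microbatch is not None:
        assert batch_size % num_microbatch == 0, "batch size must be divisible by the number of microbatches"
        microbatch_size = batch_size // num_microbatch
    elif microbatch_size is not None:
        assert batch_size % microbatch_size == 0, "batch size must be divisible by the microbatch size"
        num_microbatch = batch_size // microbatch_size
    else:
        raise ValueError("Either num_microbatch or microbatch_size should be provided")
    # the interleaved schedule additionally requires an even split across model chunks
    assert num_microbatch % num_model_chunks == 0
    return num_microbatch, microbatch_size

# e.g. batch_size=32, num_model_chunks=2, num_microbatch=8 -> microbatch_size=4
print(resolve_microbatch(32, 2, num_microbatch=8))  # (8, 4)
```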
14 changes: 8 additions & 6 deletions colossalai/pipeline/stage_manager.py
@@ -70,7 +70,7 @@ def __init__(
def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""Is the current stage the first stage.

NOTE:
NOTE:
1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device()

@@ -79,8 +79,9 @@ def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""
if self.is_interleave and model_chunk_id is None:
model_chunk_id = self.model_chunk_id
assert self.is_interleave ^ (model_chunk_id is None), \
"model_chunk_id must be specified when using interleaved pipeline"
assert self.is_interleave ^ (
model_chunk_id is None
), "model_chunk_id must be specified when using interleaved pipeline"
if not self.is_interleave or model_chunk_id < 0:
return self.stage == 0
else:
@@ -89,7 +90,7 @@ def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""Is the current stage the last stage.

NOTE:
NOTE:
1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device()

Expand All @@ -98,8 +99,9 @@ def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""
if self.is_interleave and model_chunk_id is None:
model_chunk_id = self.model_chunk_id
assert self.is_interleave ^ (model_chunk_id is None), \
"model_chunk_id must be specified when using interleaved pipeline"
assert self.is_interleave ^ (
model_chunk_id is None
), "model_chunk_id must be specified when using interleaved pipeline"
if not self.is_interleave or model_chunk_id < 0:
return self.stage == self.num_stages - 1
else:
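For context (not part of the diff): a hedged sketch of how the reworked assertions above are exercised by callers, mirroring the policy code later in this PR; the `stage_manager` instance and module names are assumed.

```python
# Illustrative only: mirrors how policies use the checks rewritten above.
def held_special_layers(stage_manager):
    held = []
    # With interleaving, a chunk id must be given; a negative id asks about the
    # device instead of the chunk (equivalent to an is_first/last_device check).
    # Without interleaving, no chunk id is passed, satisfying the XOR assertion.
    chunk = -1 if stage_manager.is_interleave else None
    if stage_manager.is_first_stage(chunk):
        held.append("embeddings")   # placeholder for module.embeddings
    if stage_manager.is_last_stage(chunk):
        held.append("classifier")   # placeholder for the task head
    return held
```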
51 changes: 20 additions & 31 deletions colossalai/shardformer/modeling/falcon.py
@@ -1,35 +1,30 @@

import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from colossalai.pipeline.stage_manager import PipelineStageManager
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F

from transformers.utils import logging

from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)

from transformers.models.falcon.modeling_falcon import (
FalconForCausalLM,
FalconForQuestionAnswering,
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
build_alibi_tensor,
)
from transformers.models.falcon.modeling_falcon import build_alibi_tensor
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig


def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
def build_falcon_alibi_tensor(
self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
@@ -98,7 +93,7 @@ def build_falcon_alibi_tensor(

def get_tp_falcon_decoder_layer_forward():
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, dropout_add

def forward(
self: FalconDecoderLayer,
hidden_states: torch.Tensor,
@@ -155,16 +150,16 @@ def forward(
outputs = (output,) + outputs[1:]

return outputs # hidden_states, present, attentions

return forward


def get_falcon_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.falcon.modeling_falcon import FalconAttention
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

def forward(
self: FalconAttention,
@@ -191,11 +186,9 @@ def forward(
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)


past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)


if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
@@ -220,12 +213,10 @@ def forward(
attention_mask_float = (
attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)

batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
tgt_len = key_layer_.size()[1]
attention_mask_float = attention_mask_float.expand(
batch_size, self.num_heads, src_len, tgt_len
).contiguous()
attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
context_layer = me_attention(
query_layer_,
key_layer_,
@@ -236,7 +227,7 @@ def forward(
)
batch_size, seq_length, _, _ = context_layer.shape
context_layer = context_layer.reshape(batch_size, seq_length, -1)

output_tensor = self.dense(context_layer)

return output_tensor, present
@@ -280,7 +271,7 @@ def falcon_model_forward(
if past_key_values is not None:
logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
past_key_values = None

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if past_key_values is None:
@@ -394,10 +385,11 @@ def custom_forward(*inputs):
if presents is not None:
presents = self._convert_cache_to_standard_format(presents, batch_size)


if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
@@ -407,7 +399,6 @@ def custom_forward(*inputs):
else:
# always return dict for imediate stage
return {"hidden_states": hidden_states}


@staticmethod
def falcon_for_causal_lm_forward(
@@ -434,7 +425,7 @@ def falcon_for_causal_lm_forward(
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if output_attentions:
@@ -489,11 +480,10 @@ def falcon_for_causal_lm_forward(
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}


@staticmethod
def falcon_for_sequence_classification_forward(
@@ -552,7 +542,7 @@ def falcon_for_sequence_classification_forward(
batch_size = hidden_states.shape[0]
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
@@ -605,7 +595,6 @@ def falcon_for_sequence_classification_forward(
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}


@staticmethod
def falcon_for_token_classification_forward(
@@ -684,11 +673,11 @@ def falcon_for_token_classification_forward(
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}

@staticmethod
def falcon_for_question_answering_forward(
self: FalconForQuestionAnswering,
@@ -780,4 +769,4 @@ def falcon_for_question_answering_forward(
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
return {"hidden_states": hidden_states}
28 changes: 8 additions & 20 deletions colossalai/shardformer/policies/bert.py
@@ -257,38 +257,28 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli

if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_manager.stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages
num_stages=stage_manager.num_stages,
)
method_replacement = {
"forward": partial(
new_forward,
stage_manager=stage_manager,
shard_config=self.shard_config
)
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}

else:
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
"forward": partial(
new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}

self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
@@ -304,14 +294,13 @@ def get_held_layers(self) -> List[Module]:
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_indices = Policy.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages
num_stages=stage_manager.num_stages,
)
if stage_manager.is_first_stage(-1):
held_layers.append(module.embeddings)
@@ -518,8 +507,7 @@ def get_held_layers(self) -> List[Module]:
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(
None if not stage_manager.is_interleave else -1):
if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
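For context (not part of the diff): with interleaving, the policy above splits the encoder layers across `num_stages * num_model_chunks` slots instead of `num_stages`. Below is an illustrative even split; ColossalAI's actual `Policy.distribute_layers` may weight stages differently.

```python
# Illustrative even split; not ColossalAI's exact implementation.
def distribute_layers(num_layers: int, num_slots: int) -> list:
    base, rem = divmod(num_layers, num_slots)
    # hand the remainder out one extra layer at a time to the leading slots
    return [base + (1 if i < rem else 0) for i in range(num_slots)]

# 24 BERT layers, 4 pipeline stages, 2 model chunks -> 8 slots of 3 layers each
print(distribute_layers(24, 4 * 2))  # [3, 3, 3, 3, 3, 3, 3, 3]
```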