Merged
6 changes: 4 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -50,8 +50,10 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp
 
     def sync_shared_params(self):
         for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
-            param = shared_param[self.stage_manager.stage]
-            dist.all_reduce(param.grad, group=group)
+            if self.stage_manager.stage in shared_param:
+                param = shared_param[self.stage_manager.stage]
+                dist.all_reduce(param.grad, group=group)
+            dist.barrier()
 
     def no_sync(self) -> Iterator[None]:
         # no sync grads across data parallel
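Note on the `sync_shared_params` change: each entry of `self.shared_params` maps a pipeline-stage index to that stage's replica of one logically shared weight, so a rank should join the all-reduce only when its own stage appears in the mapping; the added `dist.barrier()` keeps ranks that skip the reduction in step with those that perform it. Below is a minimal sketch of the bookkeeping this assumes; the stage indices and shapes are illustrative, not taken from the source.

```python
# Hypothetical illustration of the shared-parameter mapping assumed above:
# each dict maps stage index -> that stage's replica of one shared weight.
import torch

embed_stage0 = torch.nn.Parameter(torch.randn(100, 16))  # e.g. input embedding on stage 0
embed_stage3 = torch.nn.Parameter(torch.randn(100, 16))  # e.g. tied lm_head on stage 3
shared_params = [{0: embed_stage0, 3: embed_stage3}]

stage = 1                          # this rank's pipeline stage
for shared_param in shared_params:
    if stage in shared_param:      # stage 1 holds no replica, so it skips the all-reduce
        param = shared_param[stage]
        print("would all-reduce grad of", tuple(param.shape))
    # every rank would still reach dist.barrier() here, keeping the group in step
```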
6 changes: 5 additions & 1 deletion colossalai/pipeline/p2p.py
@@ -3,6 +3,7 @@
 
 import io
 import pickle
+import re
 from typing import Any, List, Optional, Union
 
 import torch
@@ -31,7 +32,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     if b'cuda' in buf:
         buf_array = bytearray(buf)
         device_index = torch.cuda.current_device()
-        buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
+        # There might be more than one output tensor during forward
+        for cuda_str in re.finditer(b'cuda', buf_array):
+            pos = cuda_str.start()
+            buf_array[pos + 5] = 48 + device_index
         buf = bytes(buf_array)
 
     io_bytes = io.BytesIO(buf)
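Background for the p2p fix: `pickle.dumps` on a CUDA tensor embeds its device string (for example `b'cuda:0'`) in the byte stream, and offset `+ 5` lands on the digit after `cuda:` (48 is `ord('0')`). The old code patched only the first occurrence via `find`, so when a stage's forward returned more than one tensor, later device strings kept the sender's device index and deserialization could target the wrong GPU. A standalone sketch of the byte rewrite; no GPU is needed and `device_index` is hard-coded for illustration:

```python
import re

# Pretend this is a pickled payload that mentions two sender-side devices.
buf = b"...device='cuda:0'...device='cuda:0'..."
device_index = 1  # receiver's device, hard-coded for the sketch

buf_array = bytearray(buf)
for cuda_str in re.finditer(b'cuda', buf_array):
    buf_array[cuda_str.start() + 5] = 48 + device_index  # patch the digit after 'cuda:'

print(bytes(buf_array))  # b"...device='cuda:1'...device='cuda:1'..."
```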
2 changes: 1 addition & 1 deletion colossalai/pipeline/schedule/_utils.py
@@ -86,7 +86,7 @@ def retain_grad(x: Any) -> None:
     Args:
         x (Any): Object to be called.
     """
-    if isinstance(x, torch.Tensor):
+    if isinstance(x, torch.Tensor) and x.requires_grad:
        x.retain_grad()
 
 
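The added `requires_grad` guard matters because stage inputs and outputs can include tensors that do not require gradients (attention masks, integer ids), and PyTorch raises a `RuntimeError` when `retain_grad()` is called on such a tensor:

```python
import torch

x = torch.ones(3)                             # requires_grad=False
try:
    x.retain_grad()
except RuntimeError as e:
    print("guard needed:", e)                 # can't retain_grad on such a tensor

y = torch.ones(3, requires_grad=True) * 2     # non-leaf tensor
if isinstance(y, torch.Tensor) and y.requires_grad:
    y.retain_grad()                           # safe: y.grad is populated by backward
y.sum().backward()
print(y.grad)                                 # tensor([1., 1., 1.])
```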
11 changes: 9 additions & 2 deletions colossalai/pipeline/schedule/one_f_one_b.py
@@ -107,8 +107,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            for k, grad in output_obj_grad.items():
-                optimizer.backward_by_grad(output_obj[k], grad)
+            if "backward_tensor_keys" not in output_obj:
+                for k, grad in output_obj_grad.items():
+                    optimizer.backward_by_grad(output_obj[k], grad)
+            else:
+                for k, grad in output_obj_grad.items():
+                    output_obj[k].grad = grad
+                for k in output_obj["backward_tensor_keys"]:
+                    tensor_to_backward = output_obj[k]
+                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
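The `backward_tensor_keys` branch exists for stages whose output dict mixes tensors to differentiate through with tensors that are merely forwarded (T5's `position_bias`, for example). The received gradients are first attached to every keyed output, and autograd is then driven only through the listed keys. A minimal sketch of that control flow, with a stand-in for `OptimizerWrapper.backward_by_grad` (the stand-in is an assumption of this sketch, not the real class):

```python
import torch

def backward_by_grad(tensor, grad):
    # Stand-in for OptimizerWrapper.backward_by_grad, for this sketch only.
    torch.autograd.backward(tensor, grad_tensors=grad)

x = torch.randn(4, requires_grad=True)
hidden = x * 3                # should be differentiated through
hidden.retain_grad()          # allow attaching/reading .grad on a non-leaf tensor
aux = x.detach() + 1          # merely forwarded, never differentiated

output_obj = {"hidden_states": hidden, "aux": aux,
              "backward_tensor_keys": ["hidden_states"]}
output_obj_grad = {"hidden_states": torch.ones(4)}   # grads from the next stage

for k, grad in output_obj_grad.items():
    output_obj[k].grad = grad                        # attach received grads
for k in output_obj["backward_tensor_keys"]:
    tensor_to_backward = output_obj[k]
    backward_by_grad(tensor_to_backward, tensor_to_backward.grad)

print(x.grad)                 # tensor([3., 3., 3., 3.])
```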
7 changes: 7 additions & 0 deletions colossalai/shardformer/layer/utils.py
@@ -122,6 +122,13 @@ def increment_index():
         """
         Randomizer._INDEX += 1
 
+    @staticmethod
+    def reset_index():
+        """
+        Reset the index to zero.
+        """
+        Randomizer._INDEX = 0
+
     @staticmethod
     def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
         """
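`reset_index` complements `increment_index`: `_INDEX` is class-level state used when creating per-layer randomizers, so a process that shards several models in a row (unit tests, most likely) can restore the counter instead of inheriting the previous model's offset. A hypothetical mirror of the counter, for illustration only:

```python
class Randomizer:
    _INDEX = 0   # simplified stand-in for colossalai.shardformer.layer.utils.Randomizer

    @staticmethod
    def increment_index():
        Randomizer._INDEX += 1

    @staticmethod
    def reset_index():
        Randomizer._INDEX = 0

Randomizer.increment_index()
Randomizer.increment_index()
assert Randomizer._INDEX == 2
Randomizer.reset_index()      # e.g. between two sharding runs in one process
assert Randomizer._INDEX == 0
```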
95 changes: 43 additions & 52 deletions colossalai/shardformer/modeling/t5.py
@@ -238,7 +238,8 @@ def custom_forward(*inputs):
             return {
                 'hidden_states': hidden_states,
                 'position_bias': position_bias,
-                'encoder_decoder_position_bias': encoder_decoder_position_bias
+                'encoder_decoder_position_bias': encoder_decoder_position_bias,
+                'backward_tensor_keys': ['hidden_states']
             }
 
     @staticmethod
@@ -261,8 +262,10 @@ def t5_model_forward(
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        backward_tensor_keys: Optional[List[str]] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
@@ -303,7 +306,6 @@ def t5_model_forward(
         decoder_head_mask = head_mask
 
         in_decoder = stage_manager.stage >= decoder_starting_stage
-
         # Stage is in encoder, directly return the output of t5_stack_forward
         if not in_decoder:
             encoder_outputs = T5PipelineForwards.t5_stack_forward(
@@ -323,25 +325,18 @@ def t5_model_forward(
                 decoder_starting_stage=decoder_starting_stage)
             if stage_manager.stage == decoder_starting_stage - 1:
                 # last stage of encoder
-                return {'encoder_outputs': encoder_outputs}
+                return {'encoder_hidden_states': encoder_outputs[0]}
             else:
                 return encoder_outputs
 
         at_last_decoder_stage = stage_manager.is_last_stage()
         at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
 
-        if encoder_outputs is None:
-            raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
-
-        encoder_hidden_states = encoder_outputs[0]
-        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
-            )
+        if encoder_outputs is not None:
+            encoder_hidden_states = encoder_outputs[0]
+        elif encoder_hidden_states is None:
+            raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
 
-        # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
         if not at_first_decoder_stage and hidden_states is None:
             raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
 
@@ -360,6 +355,7 @@ def t5_model_forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            stage_manager=stage_manager,
             hidden_states=hidden_states,
             position_bias=position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -368,22 +364,19 @@ def t5_model_forward(
 
         # Directly return outputs of overloaded T5Stack forward if not at last stage.
         if not at_last_decoder_stage:
-            decoder_outputs['encoder_outputs'] = encoder_outputs    # encoder_outputs should be passed to the next stage
+            # encoder_hidden_states should be passed to the next stage
+            decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
             return decoder_outputs
 
         if not return_dict:
-            return decoder_outputs + encoder_outputs
-
-        return Seq2SeqModelOutput(
-            last_hidden_state=decoder_outputs.last_hidden_state,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        )
+            return decoder_outputs + encoder_hidden_states
+        else:
+            return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state,
+                                      past_key_values=decoder_outputs.past_key_values,
+                                      decoder_hidden_states=decoder_outputs.hidden_states,
+                                      decoder_attentions=decoder_outputs.attentions,
+                                      cross_attentions=decoder_outputs.cross_attentions,
+                                      encoder_last_hidden_state=encoder_hidden_states)
 
     @staticmethod
     def t5_for_conditional_generation_forward(
@@ -406,8 +399,10 @@ def t5_for_conditional_generation_forward(
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        backward_tensor_keys: Optional[List[str]] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
@@ -468,28 +463,25 @@ def t5_for_conditional_generation_forward(
                 decoder_starting_stage=decoder_starting_stage)
             if stage_manager.stage == decoder_starting_stage - 1:
                 # last stage of encoder
-                return {'encoder_outputs': encoder_outputs}
+                return {'encoder_hidden_states': encoder_outputs[0]}
             else:
                 return encoder_outputs
 
         at_last_decoder_stage = stage_manager.is_last_stage()
         at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
 
-        if encoder_outputs is None:
-            raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
+        if encoder_outputs is not None:
+            encoder_hidden_states = encoder_outputs[0]
+        elif encoder_hidden_states is None:
+            raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
 
-        encoder_hidden_states = encoder_outputs[0]
-        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
-            )
-
-        # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
         if not at_first_decoder_stage and hidden_states is None:
             raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
+
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(labels)
 
         # Decode
         decoder_outputs = T5PipelineForwards.t5_stack_forward(
             self.decoder,
@@ -505,6 +497,7 @@ def t5_for_conditional_generation_forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            stage_manager=stage_manager,
             hidden_states=hidden_states,
             position_bias=position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -513,7 +506,8 @@ def t5_for_conditional_generation_forward(
 
         # Directly return outputs of overloaded T5Stack forward if not at last stage.
         if not at_last_decoder_stage:
-            decoder_outputs['encoder_outputs'] = encoder_outputs    # encoder_outputs should be passed to the next stage
+            # encoder_hidden_states should be passed to the next stage
+            decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
             return decoder_outputs
 
         sequence_output = decoder_outputs[0]
@@ -533,20 +527,16 @@ def t5_for_conditional_generation_forward(
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
 
         if not return_dict:
-            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states
             return ((loss,) + output) if loss is not None else output
 
-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        )
+        return Seq2SeqLMOutput(loss=loss,
+                               logits=lm_logits,
+                               past_key_values=decoder_outputs.past_key_values,
+                               decoder_hidden_states=decoder_outputs.hidden_states,
+                               decoder_attentions=decoder_outputs.attentions,
+                               cross_attentions=decoder_outputs.cross_attentions,
+                               encoder_last_hidden_state=encoder_hidden_states)
 
     @staticmethod
     def t5_encoder_model_forward(
@@ -562,6 +552,7 @@ def t5_encoder_model_forward(
         hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        backward_tensor_keys: Optional[List[str]] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
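Taken together, the modeling changes switch decoder stages from shipping the whole `encoder_outputs` tuple to shipping a single `encoder_hidden_states` tensor (a plain tensor is cheaper to p2p-communicate and needs no `BaseModelOutput` reconstruction on every stage), and they label checkpointed `T5Stack` outputs with `backward_tensor_keys` so the 1F1B schedule differentiates only through `hidden_states` while the position-bias tensors ride along. A sketch of the stage-to-stage payload this implies; the shapes are illustrative, not from the source:

```python
import torch

batch, seq, d_model, n_heads = 2, 16, 512, 8

# What a non-final decoder stage hands to the next stage after this PR:
stage_output = {
    'hidden_states': torch.randn(batch, seq, d_model, requires_grad=True),
    'position_bias': torch.randn(1, n_heads, seq, seq),
    'encoder_decoder_position_bias': torch.randn(1, n_heads, seq, seq),
    'encoder_hidden_states': torch.randn(batch, seq, d_model),  # was: the full encoder_outputs tuple
    'backward_tensor_keys': ['hidden_states'],  # only this entry is differentiated through
}
```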
51 changes: 13 additions & 38 deletions colossalai/shardformer/policies/t5.py
@@ -260,7 +260,7 @@ def get_held_layers(self) -> List[nn.Module]:
 
         model = self.model
         encoder = self.model.encoder
-        decoder = self.model.__dict__.get('decoder', None)
+        decoder = getattr(self.model, 'decoder', None)
 
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
@@ -300,7 +300,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         stage_manager = self.pipeline_stage_manager
 
         encoder = self.model.encoder
-        decoder = self.model.__dict__.get('decoder', None)
+        decoder = getattr(self.model, 'decoder', None)
 
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
@@ -355,15 +355,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
             return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
         return []
 
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
-            for k, v in binding_map.items():
-                src = getattr_(self.model, k)
-                for dst in v:
-                    setattr_(self.model, dst, src)
-        return self.model
-
 
 class T5ForConditionalGenerationPolicy(T5BasePolicy):
 
@@ -409,28 +400,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
                                                               stage_manager.num_stages)
 
             shared_params = []
+            shared_embedding = {}
             if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
-                shared_params.append({
-                    0: module.shared.weight,
-                    decoder_starting_stage: module.decoder.embed_tokens.weight
-                })
+                shared_embedding[0] = module.shared.weight
+                shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight
+
             if id(module.lm_head.weight) == id(module.shared.weight):
-                shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
-                return shared_params
-        return []
+                shared_embedding[0] = module.shared.weight
+                shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight
 
-    def postprocess(self):
-        super().postprocess()
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {
-                "shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
-            }
-            for k, v in binding_map.items():
-                src = getattr_(self.model, k)
-                for dst in v:
-                    setattr_(self.model, dst, src)
+            if len(shared_embedding) > 0:
+                shared_params.append(shared_embedding)
 
-        return self.model
+            return shared_params
+
+        return []
 
 
 class T5EncoderPolicy(T5BasePolicy):
@@ -462,12 +446,3 @@ def get_held_layers(self) -> List[nn.Module]:
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         return []
-
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
-            for k, v in binding_map.items():
-                src = getattr_(self.model, k)
-                for dst in v:
-                    setattr_(self.model, dst, src)
-        return self.model
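Two separate fixes land in the policy file. First, `self.model.__dict__.get('decoder', None)` becomes `getattr(self.model, 'decoder', None)`: `nn.Module` registers submodules in `self._modules`, not in `__dict__`, so the old lookup returned `None` even when a decoder existed and its layers were never counted. Second, the tensor-parallel `postprocess` re-binding of tied weights is dropped, and `get_shared_params` now merges every tying (shared embedding, decoder `embed_tokens`, `lm_head`) into one stage-indexed dict, so a single process group can cover every stage that holds a replica. The `__dict__` pitfall in isolation:

```python
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(4, 4)   # registered in _modules, not in __dict__

m = Toy()
print(m.__dict__.get('decoder', None))   # None: submodules never land in __dict__
print(getattr(m, 'decoder', None))       # Linear(in_features=4, out_features=4, bias=True)
```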
16 changes: 15 additions & 1 deletion colossalai/shardformer/shard/sharder.py
@@ -198,14 +198,28 @@ def _replace_sub_module(
 
         setattr_(org_layer, suffix, replace_layer)
 
+    def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
+
+        def collect_sub_modules(module: nn.Module):
+            if module is None:
+                return
+            recursive_held_layers.append(module)
+            for name, child in module.named_children():
+                collect_sub_modules(child)
+
+        recursive_held_layers = []
+        for module in held_layers:
+            collect_sub_modules(module)
+        return recursive_held_layers
+
     def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
         r"""
         Release the unheld layers in the model
         """
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
             set_tensors_to_none(self.model, exclude=set(held_layers))
-            return set(held_layers)
+            return set(self._get_recursive_held_layers(held_layers))
         return None
 
     def _materialize(self) -> None:
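The sharder change fixes over-eager memory release: `get_held_layers` returns top-level held modules, while the exclusion check in `set_tensors_to_none` appears to be applied per visited module, so the children of a held layer were not themselves in the exclude set and their tensors could be freed. Flattening each held layer into its full submodule tree closes that hole. A minimal sketch of the flattening, assuming ordinary `nn.Module` containers:

```python
import torch.nn as nn

def get_recursive_held_layers(held_layers):
    recursive_held_layers = []

    def collect_sub_modules(module):
        if module is None:
            return
        recursive_held_layers.append(module)
        for _, child in module.named_children():
            collect_sub_modules(child)

    for module in held_layers:
        collect_sub_modules(module)
    return recursive_held_layers

block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print(len(get_recursive_held_layers([block])))   # 3: Sequential, Linear, ReLU
```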