From 78a9bdbccabcb3c1e27892b537cce87b2f62bb8f Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 1 Feb 2021 18:21:39 -0800 Subject: [PATCH 1/4] copy base from mp-pp --- src/transformers/integrations.py | 149 +++++++++++++++++++++- src/transformers/modeling_utils.py | 67 ++++++++++ src/transformers/models/t5/modeling_t5.py | 7 + src/transformers/trainer.py | 63 ++++++++- src/transformers/training_args.py | 7 + 5 files changed, 291 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index f66989381d24..720a4acd124e 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -257,7 +257,153 @@ def rewrite_logs(d): return new_d -def init_deepspeed(trainer, num_training_steps): + +import torch + + +# Model parallel group that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +_MODEL_PARALLEL_GROUP_DEVICE_IDS = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None +_DATA_PARALLEL_GROUP_DEVICE_IDS = None + +# adjusted from Megatron-LM/mpu/ +class MPU: + def initialize_model_parallel(self, model_parallel_size_): + """ + Initialize model data parallel groups. + + Arguments: + model_parallel_size: number of GPUs used to parallelize model. + **Important**: not the total number of gpus! + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model. The present function will + create 4 model parallel groups and 2 data parallel groups as: + 4 model parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 data parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + + Let's say we have a total of 4 GPUs denoted by g0 ... g3 and we + use 2 GPUs to parallelize the model. The present function will + create 2 model parallel groups and 2 data parallel groups as: + 2 model parallel groups: + [g0, g1], [g2, g3] + 2 data parallel groups: + [g0, g2], [g1, g3] + + """ + + def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + + if torch.distributed.get_rank() == 0: + print("> initializing model parallel with size {}".format(model_parallel_size_)) + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + model_parallel_size = min(model_parallel_size_, world_size) + ensure_divisibility(world_size, model_parallel_size) + rank = torch.distributed.get_rank() + + print(f"MP size: {model_parallel_size}") + print(f"world_size: {world_size}") + print(f"rank: {rank}") + + # Build the data parallel groups. + global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_DEVICE_IDS + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" + for i in range(model_parallel_size): + ranks = range(i, world_size, model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank % model_parallel_size): + #print(f"DP ranks: {list(ranks)}") + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks) + + # Build the model parallel groups. 
+        global _MODEL_PARALLEL_GROUP
+        global _MODEL_PARALLEL_GROUP_DEVICE_IDS
+        assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+        for i in range(world_size // model_parallel_size):
+            ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
+            group = torch.distributed.new_group(ranks)
+            if i == (rank // model_parallel_size):
+                #print(f"MP ranks: {list(ranks)}")
+                _MODEL_PARALLEL_GROUP = group
+                _MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks)
+
+    def model_parallel_is_initialized(self):
+        """Check if model and data parallel groups are initialized."""
+        if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
+            return False
+        return True
+
+    def get_model_parallel_group_device_ids(self):
+        """Get the model parallel device ids of the group the caller rank belongs to."""
+        assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
+        return _MODEL_PARALLEL_GROUP_DEVICE_IDS
+
+    def get_model_parallel_group(self):
+        """Get the model parallel group the caller rank belongs to."""
+        assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
+        return _MODEL_PARALLEL_GROUP
+
+    def get_data_parallel_group_device_ids(self):
+        """Get the data parallel device ids of the group the caller rank belongs to."""
+        assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
+        return _DATA_PARALLEL_GROUP_DEVICE_IDS
+
+    def get_data_parallel_group(self):
+        """Get the data parallel group the caller rank belongs to."""
+        assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
+        return _DATA_PARALLEL_GROUP
+
+    def get_model_parallel_world_size(self):
+        """Return world size for the model parallel group."""
+        return torch.distributed.get_world_size(group=self.get_model_parallel_group())
+
+    def get_model_parallel_rank(self):
+        """Return my rank for the model parallel group."""
+        return torch.distributed.get_rank(group=self.get_model_parallel_group())
+
+    def get_model_parallel_src_rank(self):
+        """Calculate the global rank corresponding to local rank zero
+        in the model parallel group."""
+        global_rank = torch.distributed.get_rank()
+        local_world_size = self.get_model_parallel_world_size()
+        return (global_rank // local_world_size) * local_world_size
+
+    def get_data_parallel_world_size(self):
+        """Return world size for the data parallel group."""
+        return torch.distributed.get_world_size(group=self.get_data_parallel_group())
+
+    def get_data_parallel_rank(self):
+        """Return my rank for the data parallel group."""
+        return torch.distributed.get_rank(group=self.get_data_parallel_group())
+
+    def destroy_model_parallel(self):
+        """Set the groups to none."""
+        global _MODEL_PARALLEL_GROUP
+        global _MODEL_PARALLEL_GROUP_DEVICE_IDS
+        global _DATA_PARALLEL_GROUP
+        global _DATA_PARALLEL_GROUP_DEVICE_IDS
+        _MODEL_PARALLEL_GROUP = None
+        _MODEL_PARALLEL_GROUP_DEVICE_IDS = None
+        _DATA_PARALLEL_GROUP = None
+        _DATA_PARALLEL_GROUP_DEVICE_IDS = None
+
+
+def init_deepspeed(trainer, num_training_steps, mpu):
     """
     Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration
@@ -403,6 +549,7 @@ def init_deepspeed(trainer, num_training_steps):
         model=model,
         model_parameters=model_parameters,
         config_params=config,
+        mpu = mpu,
     )
 
     return model, optimizer, lr_scheduler
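A quick sanity check on the grouping arithmetic in `MPU.initialize_model_parallel` above: a minimal standalone sketch (plain Python, no `torch.distributed` required) that mirrors the 8-GPU / 2-way example from the docstring. The group lists are derived from the two loops above, not part of the patch itself:

    world_size, model_parallel_size = 8, 2

    # data parallel groups: every model_parallel_size-th rank, one group per offset
    data_parallel_groups = [
        list(range(i, world_size, model_parallel_size)) for i in range(model_parallel_size)
    ]
    # model parallel groups: consecutive blocks of model_parallel_size ranks
    model_parallel_groups = [
        list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
        for i in range(world_size // model_parallel_size)
    ]

    assert model_parallel_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
    assert data_parallel_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]

    # a given rank lands in data group (rank % mp_size) and model group (rank // mp_size),
    # which is exactly what the two `if i == ...` conditions above select
    rank = 5
    assert rank in data_parallel_groups[rank % model_parallel_size]
    assert rank in model_parallel_groups[rank // model_parallel_size]
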
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d0fc1ad0f4b2..7a2b53773543 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1785,3 +1785,70 @@ def forward(self, hidden_states):
         return torch.cat(output_chunks, dim=chunk_dim)
 
     return forward_fn(*input_tensors)
+
+
+def recursive_to(device, item):
+    """
+    Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list,
+    tuple and dict structures.
+    """
+
+    if torch.is_tensor(item):
+        return item.to(device)
+
+    elif isinstance(item, list):
+        for i, x in enumerate(item):
+            item[i] = recursive_to(device, x)
+        return item
+
+    elif isinstance(item, tuple):
+        return tuple(recursive_to(device, list(item)))
+
+    elif isinstance(item, dict):
+        for k, v in item.items():
+            item[k] = recursive_to(device, v)
+        return item
+
+    else:
+        return item
+
+
+# tnone = torch.tensor([float('nan')]*batch_size)
+def pipe_none_or_empty_to_torch(x, batch_size, device):
+    tnone = torch.tensor([-100] * batch_size).to(device)
+    tempty = torch.empty(0).to(device)
+    if x is None:
+        return tnone
+    if isinstance(x, tuple) and x == ():
+        return tempty
+    return x
+
+
+def pipe_torch_to_none_or_empty(x, batch_size, device):
+    tnone = torch.tensor([-100] * batch_size).to(device)
+    # tempty = torch.empty(0).to(device)
+    # if torch.is_tensor(x):
+    #     print(x.shape, x)
+    # else:
+    #     print(x)
+    if torch.is_tensor(x):
+        # an empty tensor is the encoding of `()` (see pipe_none_or_empty_to_torch)
+        if not x.numel():
+            return ()
+        if x.shape == tnone.shape and all(x == tnone):
+            return None
+    return x
+
+
+def pipe_encode_all(input, batch_size, device):
+    input = list(input)
+    for i, x in enumerate(input):
+        input[i] = pipe_none_or_empty_to_torch(x, batch_size, device)
+    return tuple(input)
+
+
+def pipe_decode_all(input, batch_size, device):
+    input = list(input)
+    for i, x in enumerate(input):
+        input[i] = pipe_torch_to_none_or_empty(x, batch_size, device)
+    return tuple(input)
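The pipe_* helpers above exist because torch's pipeline machinery can only ship tensors between stages, so `None` and `()` arguments have to be encoded as sentinel tensors on the way in and decoded on the way out. A minimal CPU round-trip sketch (assumes this patch is applied so the helpers are importable; one caveat is noted in the comments):

    import torch

    from transformers.modeling_utils import pipe_decode_all, pipe_encode_all

    batch_size, device = 4, torch.device("cpu")
    inputs = (torch.ones(batch_size, 3), None, ())

    encoded = pipe_encode_all(inputs, batch_size, device)
    # None -> a (batch_size,) tensor filled with the -100 sentinel, () -> an empty
    # tensor, real tensors pass through untouched
    assert all(torch.is_tensor(x) for x in encoded)

    decoded = pipe_decode_all(encoded, batch_size, device)
    assert torch.equal(decoded[0], inputs[0])
    assert decoded[1] is None
    assert decoded[2] == ()

    # caveat: a genuine (batch_size,) input that happens to be all -100 (e.g. fully
    # masked labels) is indistinguishable from the None sentinel and decodes to None
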
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index bd05cf00d11d..195c2c44f0ca 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1416,6 +1416,13 @@ def deparallelize(self):
         self.device_map = None
         torch.cuda.empty_cache()
 
+
+    def pipeline_enable(self, chunks, device_map, mpu=None):
+        logger.info(f"enabling pipeline with chunks={chunks}")
+
+    def pipeline_finalize(self):
+        pass
+
     def get_input_embeddings(self):
         return self.shared
 
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c25c7cb42d79..25f25986e531 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -258,6 +258,58 @@ def __init__(
         else:
             self.is_model_parallel = False
 
+
+        self.mpu = None
+        # XXX: for now hack over naive MP to have the same behavior
+        if len(self.args.pipeline):
+            # using range() syntax for upper boundary (i.e. not inclusive)
+            # --pipeline "chunks=5 device_map=0:0-10,1:10-20"
+            self.is_model_parallel = True
+
+            # arg parser
+            pp_args = {}
+            pp_args_str = self.args.pipeline.split()
+            if len(pp_args_str):
+                for x in pp_args_str:
+                    k,v = x.split("=")
+                    pp_args[k] = v
+
+            if "chunks" in pp_args:
+                pp_args["chunks"] = int(pp_args["chunks"])
+            else:
+                # XXX: probably can try some smart dynamic default based on batch_size
+                pp_args["chunks"] = 2
+
+            if "device_map" in pp_args:
+                device_map_range_str = pp_args["device_map"].split(",")
+                device_map = {}
+                for x in device_map_range_str:
+                    device_id, layers = x.split(":")
+                    device_map[int(device_id)] = list(range(*map(int, layers.split("-"))))
+                pp_args["device_map"] = device_map
+            else:
+                pp_args["device_map"] = None
+
+            if "n_gpus_per_mp" in pp_args:
+                pp_args["n_gpus_per_mp"] = int(pp_args["n_gpus_per_mp"])
+            else:
+                # XXX: can try some smart dynamic default here based on total_n_gpus,
+                # if it's not 2D all gpus will be used
+                # if 2D half gpus should be a good default
+                pp_args["n_gpus_per_mp"] = 2
+
+            # 2D Parallel
+            if self.args.deepspeed:
+                from .integrations import MPU
+                self.mpu = MPU()
+                #n_gpus = torch.distributed.get_world_size()
+                # XXX: hardcoded for 2 gpus for PP/MP - needs to be configurable
+                #n_gpus_per_mp = n_gpus/2
+                # at the moment experimenting with just 4 gpus - hence 2 gpus for MP|PP, 2 for DP
+                self.mpu.initialize_model_parallel(pp_args["n_gpus_per_mp"])
+
+            model.pipeline_enable(chunks=pp_args["chunks"], device_map=pp_args["device_map"], mpu=self.mpu)
+
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
@@ -271,6 +323,9 @@ def __init__(
             # Force n_gpu to 1 to avoid DataParallel.
             self.args._n_gpu = 1
 
+        if len(self.args.pipeline):
+            model = model.to(args.device)
+
         # later use `self.model is self.model_wrapped` to check if it's wrapped or not
         self.model_wrapped = model
         self.model = model
@@ -760,7 +815,7 @@ def train(
             num_update_steps_per_epoch = max_steps
 
         if self.args.deepspeed:
-            model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
+            model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps, mpu=self.mpu)
             self.model = model.module
             self.model_wrapped = model  # will get further wrapped in DDP
             self.deepspeed = model  # DeepSpeedEngine object
@@ -995,6 +1050,9 @@ def train(
             # Clean the state at the end of training
             delattr(self, "_past")
 
+        if len(self.args.pipeline):
+            self.model.pipeline_finalize()
+
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
 
         if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
             logger.info(
@@ -1663,6 +1721,9 @@ def prediction_loop(
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
 
+        if len(self.args.pipeline):
+            self.model.pipeline_finalize()
+
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
     def _gather_and_numpify(self, tensors, name):
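For reference, a standalone sketch of what the ad-hoc `--pipeline` parser in `Trainer.__init__` above produces for a typical value (keys are whitespace-separated, and `device_map` ranges use the non-inclusive upper bound mentioned in the comment):

    pipeline = "chunks=5 device_map=0:0-10,1:10-20 n_gpus_per_mp=2"

    pp_args = dict(kv.split("=") for kv in pipeline.split())
    pp_args["chunks"] = int(pp_args.get("chunks", 2))  # micro-batches per batch
    pp_args["n_gpus_per_mp"] = int(pp_args.get("n_gpus_per_mp", 2))
    if "device_map" in pp_args:
        device_map = {}
        for entry in pp_args["device_map"].split(","):
            device_id, layers = entry.split(":")
            start, stop = map(int, layers.split("-"))
            device_map[int(device_id)] = list(range(start, stop))  # stop not inclusive
        pp_args["device_map"] = device_map

    assert pp_args["chunks"] == 5
    assert pp_args["device_map"] == {0: list(range(0, 10)), 1: list(range(10, 20))}
    assert pp_args["n_gpus_per_mp"] == 2
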
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 1e669b72196b..030177b6bc79 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -369,6 +369,13 @@ class TrainingArguments:
     )
     debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
 
+    pipeline: str = field(
+        default="",
+        metadata={
+            "help": "Whether to enable Pipeline Parallelism; the value holds the pipeline params: 'chunks=5 device_map=0:0-10,1:10-20 n_gpus_per_mp=2'"
+        },
+    )
+
     dataloader_drop_last: bool = field(
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
     )

From d4b4476473f17a520aa6a10db4beb0c9642bed78 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 1 Feb 2021 20:36:39 -0800
Subject: [PATCH 2/4] style

---
 src/transformers/integrations.py | 7 +++----
 src/transformers/trainer.py      | 8 ++++----
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 10d06e9f940f..3a41095df194 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -258,7 +258,6 @@ def rewrite_logs(d):
     return new_d
 
 
-
 import torch
 
 
@@ -327,7 +326,7 @@ def ensure_divisibility(numerator, denominator):
             ranks = range(i, world_size, model_parallel_size)
             group = torch.distributed.new_group(ranks)
             if i == (rank % model_parallel_size):
-                #print(f"DP ranks: {list(ranks)}")
+                # print(f"DP ranks: {list(ranks)}")
                 _DATA_PARALLEL_GROUP = group
                 _DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks)
 
@@ -339,7 +338,7 @@ def ensure_divisibility(numerator, denominator):
             ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
             group = torch.distributed.new_group(ranks)
             if i == (rank // model_parallel_size):
-                #print(f"MP ranks: {list(ranks)}")
+                # print(f"MP ranks: {list(ranks)}")
                 _MODEL_PARALLEL_GROUP = group
                 _MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks)
 
@@ -550,7 +549,7 @@ def init_deepspeed(trainer, num_training_steps, mpu):
         model=model,
         model_parameters=model_parameters,
         config_params=config,
-        mpu = mpu,
+        mpu=mpu,
     )
 
     return model, optimizer, lr_scheduler
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 25f25986e531..fa6fe4716e72 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -258,7 +258,6 @@ def __init__(
         else:
             self.is_model_parallel = False
 
-
         self.mpu = None
         # XXX: for now hack over naive MP to have the same behavior
         if len(self.args.pipeline):
@@ -271,7 +270,7 @@ def __init__(
             pp_args_str = self.args.pipeline.split()
             if len(pp_args_str):
                 for x in pp_args_str:
-                    k,v = x.split("=")
+                    k, v = x.split("=")
                     pp_args[k] = v
 
             if "chunks" in pp_args:
@@ -301,10 +300,11 @@ def __init__(
 
             # 2D Parallel
             if self.args.deepspeed:
                 from .integrations import MPU
+
                 self.mpu = MPU()
-                #n_gpus = torch.distributed.get_world_size()
+                # n_gpus = torch.distributed.get_world_size()
                 # XXX: hardcoded for 2 gpus for PP/MP - needs to be configurable
-                #n_gpus_per_mp = n_gpus/2
+                # n_gpus_per_mp = n_gpus/2
                 # at the moment experimenting with just 4 gpus - hence 2 gpus for MP|PP, 2 for DP
                 self.mpu.initialize_model_parallel(pp_args["n_gpus_per_mp"])
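To make the 2D-parallel hookup above concrete: `Trainer` builds the `MPU` and hands it to `deepspeed.initialize()` via the new `mpu` argument, and DeepSpeed is expected to confine its data-parallel communication (gradient all-reduce, ZeRO partitioning) to `mpu.get_data_parallel_group()`, leaving the model/pipeline-parallel axis alone. A rough sketch of the intended setup, assuming a 4-process `torch.distributed` launch; the exact set of methods DeepSpeed calls on the mpu object is an assumption worth verifying against the DeepSpeed docs:

    import torch.distributed as dist

    from transformers.integrations import MPU

    if not dist.is_initialized():
        dist.init_process_group("nccl")  # normally done by the distributed launcher

    mpu = MPU()
    mpu.initialize_model_parallel(2)  # 2 GPUs per MP/PP group

    # with world_size=4 this gives 2-way model parallelism x 2-way data parallelism
    assert mpu.get_model_parallel_world_size() == 2
    assert mpu.get_data_parallel_world_size() == dist.get_world_size() // 2

    # engine, optimizer, _, lr_scheduler = deepspeed.initialize(..., mpu=mpu)
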
From 4c0ea522157f693bccce80c4cbecc24019186676 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 1 Feb 2021 20:36:51 -0800
Subject: [PATCH 3/4] sequential stage 1

---
 src/transformers/models/t5/modeling_t5.py | 195 +++++++++++++++-------
 1 file changed, 134 insertions(+), 61 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 195c2c44f0ca..74bb09d0baed 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -776,6 +776,95 @@ def _shift_right(self, input_ids):
         return shifted_input_ids
 
 
+class T5StackPipeSegment(nn.Module):
+    def __init__(
+        self,
+        idx,
+        layer_module,
+        is_decoder,
+    ):
+        super().__init__()
+        self.idx = idx
+        self.layer_module = layer_module
+        self.is_decoder = is_decoder
+
+    def forward(self, input):
+        # print(f"!!!!!!!!!!!!!!!!!!! {self.idx}")
+
+        # unpack
+        (
+            hidden_states,
+            attention_mask,
+            position_bias,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            encoder_decoder_position_bias,
+            head_mask,
+            encoder_head_mask,
+            past_key_values,
+            use_cache,
+            output_attentions,
+            all_hidden_states,
+            output_hidden_states,
+            present_key_value_states,
+        ) = input
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        layer_outputs = self.layer_module(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            encoder_decoder_position_bias=encoder_decoder_position_bias,
+            layer_head_mask=head_mask[self.idx],
+            encoder_layer_head_mask=encoder_head_mask[self.idx],
+            past_key_value=past_key_values[self.idx],
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+        # layer_outputs is a tuple with:
+        # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+        hidden_states, present_key_value_state = layer_outputs[:2]
+
+        # We share the position biases between the layers - the first layer stores them
+        # layer_outputs = hidden-states, key-value-states (self-attention weights),
+        # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+        position_bias = layer_outputs[2]
+        if self.is_decoder and encoder_hidden_states is not None:
+            encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+        # append next layer key value states
+        if use_cache:
+            present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+        if output_attentions:
+            all_attentions = all_attentions + (layer_outputs[3],)
+            if self.is_decoder:
+                all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+        # pack
+        outputs = (
+            hidden_states,
+            attention_mask,
+            position_bias,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            encoder_decoder_position_bias,
+            head_mask,
+            encoder_head_mask,
+            past_key_values,
+            use_cache,
+            output_attentions,
+            all_hidden_states,
+            output_hidden_states,
+            present_key_value_states,
+        )
+
+        return outputs
+
+
 class T5Stack(T5PreTrainedModel):
     def __init__(self, config, embed_tokens=None):
         super().__init__(config)
@@ -919,67 +1008,52 @@ def forward(
 
         hidden_states = self.dropout(inputs_embeds)
 
-        for i, (layer_module, past_key_value) in 
enumerate(zip(self.block, past_key_values)): - layer_head_mask = head_mask[i] - encoder_layer_head_mask = encoder_head_mask[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if position_bias is not None: - position_bias = position_bias.to(hidden_states.device) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) - if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) - if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) - if layer_head_mask is not None: - layer_head_mask = layer_head_mask.to(hidden_states.device) - if encoder_layer_head_mask is not None: - encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - encoder_layer_head_mask=encoder_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, + # rewrite the model after pre-trained weights were loaded + layers = [ + T5StackPipeSegment( + idx, + layer_module, + self.is_decoder, ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention weights), - # (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) + for idx, layer_module in enumerate(self.block) + ] + net = nn.Sequential(*layers) + + input = ( + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + head_mask, + encoder_head_mask, + past_key_values, + use_cache, + output_attentions, + all_hidden_states, + output_hidden_states, + present_key_value_states, + ) + output = net(input) + + # unpack + ( + hidden_states, + extended_attention_mask, + 
position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + head_mask, + encoder_head_mask, + past_key_values, + use_cache, + output_attentions, + all_hidden_states, + output_hidden_states, + present_key_value_states, + ) = output hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1416,7 +1490,6 @@ def deparallelize(self): self.device_map = None torch.cuda.empty_cache() - def pipeline_enable(self, chunks, device_map, mpu=None): logger.info(f"enabling pipeline with chunks={chunks}") From d3d4f9c900450852fce54a89712968b700ff02da Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 1 Feb 2021 21:44:55 -0800 Subject: [PATCH 4/4] sequential stage 2 - split t5stack --- src/transformers/models/t5/modeling_t5.py | 91 +++++++++-------------- 1 file changed, 37 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 74bb09d0baed..76a57bf8c885 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -807,6 +807,8 @@ def forward(self, input): all_hidden_states, output_hidden_states, present_key_value_states, + all_attentions, + all_cross_attentions, ) = input if output_hidden_states: @@ -860,6 +862,8 @@ def forward(self, input): all_hidden_states, output_hidden_states, present_key_value_states, + all_attentions, + all_cross_attentions, ) return outputs @@ -871,50 +875,16 @@ def __init__(self, config, embed_tokens=None): self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder + self.config = config self.block = nn.ModuleList( [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) self.init_weights() - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map - ) - assert_device_map(self.device_map, len(self.block)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - # Load onto devices - for k, v in self.device_map.items(): - for layer in v: - cuda_device = "cuda:" + str(k) - self.block[layer] = self.block[layer].to(cuda_device) - - # Set embed_tokens to first layer - self.embed_tokens = self.embed_tokens.to(self.first_device) - # Set final layer norm to last device - self.final_layer_norm = self.final_layer_norm.to(self.last_device) - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def deparallelize(self): - self.model_parallel = False - self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - for i in range(len(self.block)): - self.block[i] = self.block[i].to("cpu") - self.embed_tokens = self.embed_tokens.to("cpu") - self.final_layer_norm = self.final_layer_norm.to("cpu") - torch.cuda.empty_cache() def get_input_embeddings(self): return self.embed_tokens @@ -937,10 +907,6 @@ def forward( output_hidden_states=None, return_dict=None, ): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(self.first_device) - self.embed_tokens = 
self.embed_tokens.to(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1008,17 +974,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - # rewrite the model after pre-trained weights were loaded - layers = [ - T5StackPipeSegment( - idx, - layer_module, - self.is_decoder, - ) - for idx, layer_module in enumerate(self.block) - ] - net = nn.Sequential(*layers) - + # pack input = ( hidden_states, extended_attention_mask, @@ -1034,9 +990,34 @@ def forward( all_hidden_states, output_hidden_states, present_key_value_states, + all_attentions, + all_cross_attentions, ) - output = net(input) + # rewrite the model after pre-trained weights were loaded + layers = [ + T5StackPipeSegment( + idx, + layer_module, + self.is_decoder, + ).to(input_ids.device) + for idx, layer_module in enumerate(self.block) + ] + + layers.append(T5StackLast(return_dict, self.final_layer_norm, self.dropout).to(input_ids.device)) + net = nn.Sequential(*layers) + return net(input) + + +class T5StackLast(nn.Module): + def __init__(self, return_dict, final_layer_norm, dropout): + super().__init__() + + self.return_dict = return_dict + self.final_layer_norm = final_layer_norm + self.dropout = dropout + + def forward(self, input): # unpack ( hidden_states, @@ -1053,7 +1034,9 @@ def forward( all_hidden_states, output_hidden_states, present_key_value_states, - ) = output + all_attentions, + all_cross_attentions, + ) = input hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1062,7 +1045,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: + if not self.return_dict: return tuple( v for v in [