diff --git a/examples/legacy/seq2seq/finetune_trainer.py b/examples/legacy/seq2seq/finetune_trainer.py index 37573e50bad7..fdb086194966 100755 --- a/examples/legacy/seq2seq/finetune_trainer.py +++ b/examples/legacy/seq2seq/finetune_trainer.py @@ -178,6 +178,8 @@ def main(): transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_explicit_format() # Set the verbosity to info of the Transformers logger (on main process only): + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() if is_main_process(training_args.local_rank): transformers.utils.logging.set_verbosity_info() logger.info("Training/evaluation parameters %s", training_args) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index e5293a1caaf9..4da47470831f 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -269,7 +269,153 @@ def rewrite_logs(d): return new_d -def init_deepspeed(trainer, num_training_steps): + +import torch + + +# Model parallel group that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +_MODEL_PARALLEL_GROUP_DEVICE_IDS = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None +_DATA_PARALLEL_GROUP_DEVICE_IDS = None + +# adjusted from Megatron-LM/mpu/ +class MPU: + def initialize_model_parallel(self, model_parallel_size_): + """ + Initialize model data parallel groups. + + Arguments: + model_parallel_size: number of GPUs used to parallelize model. + **Important**: not the total number of gpus! + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model. The present function will + create 4 model parallel groups and 2 data parallel groups as: + 4 model parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 data parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + + Let's say we have a total of 4 GPUs denoted by g0 ... g3 and we + use 2 GPUs to parallelize the model. The present function will + create 2 model parallel groups and 2 data parallel groups as: + 2 model parallel groups: + [g0, g1], [g2, g3] + 2 data parallel groups: + [g0, g2], [g1, g3] + + """ + + def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + + if torch.distributed.get_rank() == 0: + print("> initializing model parallel with size {}".format(model_parallel_size_)) + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + model_parallel_size = min(model_parallel_size_, world_size) + ensure_divisibility(world_size, model_parallel_size) + rank = torch.distributed.get_rank() + + print(f"MP size: {model_parallel_size}") + print(f"world_size: {world_size}") + print(f"rank: {rank}") + + # Build the data parallel groups. 
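+        # Each data parallel group is built from ranks {i, i + mp_size, i + 2*mp_size, ...};
+        # e.g. with world_size=8 and model_parallel_size=2 this yields [g0, g2, g4, g6] and [g1, g3, g5, g7].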
+ global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_DEVICE_IDS + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" + for i in range(model_parallel_size): + ranks = range(i, world_size, model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank % model_parallel_size): + #print(f"DP ranks: {list(ranks)}") + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks) + + # Build the model parallel groups. + global _MODEL_PARALLEL_GROUP + global _MODEL_PARALLEL_GROUP_DEVICE_IDS + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + for i in range(world_size // model_parallel_size): + ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank // model_parallel_size): + #print(f"MP ranks: {list(ranks)}") + _MODEL_PARALLEL_GROUP = group + _MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks) + + def model_parallel_is_initialized(self): + """Check if model and data parallel groups are initialized.""" + if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + def get_model_parallel_group_device_ids(self): + """Get the model parallel device ids of the group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return _MODEL_PARALLEL_GROUP_DEVICE_IDS + + def get_model_parallel_group(self): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return _MODEL_PARALLEL_GROUP + + def get_data_parallel_group_device_ids(self): + """Get the data parallel device ids of the group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" + return _DATA_PARALLEL_GROUP_DEVICE_IDS + + def get_data_parallel_group(self): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" + return _DATA_PARALLEL_GROUP + + def get_model_parallel_world_size(self): + """Return world size for the model parallel group.""" + return torch.distributed.get_world_size(group=self.get_model_parallel_group()) + + def get_model_parallel_rank(self): + """Return my rank for the model parallel group.""" + return torch.distributed.get_rank(group=self.get_model_parallel_group()) + + def get_model_parallel_src_rank(self): + """Calculate the global rank corresponding to a local rank zero + in the model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + def get_data_parallel_world_size(self): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=self.get_data_parallel_group()) + + def get_data_parallel_rank(self): + """Return my rank for the data parallel group.""" + return torch.distributed.get_rank(group=self.get_data_parallel_group()) + + def destroy_model_parallel(self): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + global _MODEL_PARALLEL_GROUP_DEVICE_IDS + global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_DEVICE_IDS + _MODEL_PARALLEL_GROUP = None + _MODEL_PARALLEL_GROUP_DEVICE_IDS = None + _DATA_PARALLEL_GROUP = None + _DATA_PARALLEL_GROUP_DEVICE_IDS = None + + +def init_deepspeed(trainer, num_training_steps, mpu): """ 
Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration @@ -415,6 +561,7 @@ def init_deepspeed(trainer, num_training_steps): model=model, model_parameters=model_parameters, config_params=config, + mpu = mpu, ) return model, optimizer, lr_scheduler diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 16a5f0452da2..68b05884d13d 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1793,3 +1793,70 @@ def forward(self, hidden_states): return torch.cat(output_chunks, dim=chunk_dim) return forward_fn(*input_tensors) + + +def recursive_to(device, item): + """ + Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list, + tuple and dict structures. + """ + + if torch.is_tensor(item): + return item.to(device) + + elif isinstance(item, list): + for i, x in enumerate(item): + item[i] = recursive_to(device, x) + return item + + elif isinstance(item, tuple): + return tuple(recursive_to(device, list(item))) + + elif isinstance(item, dict): + for k, v in item.items(): + item[k] = recursive_to(device, v) + return item + + else: + return item + + +# tnone = torch.tensor([float('nan')]*batch_size) +def pipe_none_or_empty_to_torch(x, batch_size, device): + tnone = torch.tensor([-100] * batch_size).to(device) + tempty = torch.empty(0).to(device) + if x is None: + return tnone.to(device) + if x == (): + return tempty.to(device) + return x + + +def pipe_torch_to_none_or_empty(x, batch_size, device): + tnone = torch.tensor([-100] * batch_size).to(device) + # tempty = torch.empty(0).to(device) + # if torch.is_tensor(x): + # print(x.shape, x) + # else: + # print(x) + if torch.is_tensor(x) and x.shape[0] == batch_size: + if not x.numel(): + return () + # print(x.numel(), batch_size, x, tnone) + if x.shape == tnone.shape and all(x == tnone): + return None + return x + + +def pipe_encode_all(input, batch_size, device): + input = list(input) + for i, x in enumerate(input): + input[i] = pipe_none_or_empty_to_torch(x, batch_size, device) + return tuple(input) + + +def pipe_decode_all(input, batch_size, device): + input = list(input) + for i, x in enumerate(input): + input[i] = pipe_torch_to_none_or_empty(x, batch_size, device) + return tuple(input) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 6ed8037f17cd..f613628494c7 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -23,6 +23,7 @@ import torch import torch.nn.functional as F from torch import nn +from torch.distributed.pipeline.sync import Pipe from torch.nn import CrossEntropyLoss from ...activations import ACT2FN @@ -39,7 +40,16 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import ( + PreTrainedModel, + find_pruneable_heads_and_indices, + pipe_decode_all, + pipe_encode_all, + pipe_none_or_empty_to_torch, + pipe_torch_to_none_or_empty, + prune_linear_layer, + recursive_to, +) from ...utils import logging from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_t5 import T5Config @@ -243,6 +253,7 @@ def forward(self, hidden_states): # convert into float16 if necessary if self.weight.dtype == torch.float16: hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states @@ -776,6 +787,153 @@ def _shift_right(self, 
input_ids): return shifted_input_ids +class T5StackPipeSegment(nn.Module): + def __init__( + self, + idx, + n_layers, + layer_module, + is_decoder, + layer_head_mask, + encoder_layer_head_mask, + output_hidden_states, + use_cache, + output_attentions, + all_hidden_states_add, + present_key_value_states_add, + all_attentions_add, + all_cross_attentions_add, + ): + super().__init__() + self.batch_id = -1 + self.idx = idx + self.n_layers = n_layers + self.layer_module = layer_module + self.is_decoder = is_decoder + self.layer_head_mask = layer_head_mask + self.encoder_layer_head_mask = encoder_layer_head_mask + # self.past_key_value = past_key_value + self.output_hidden_states = output_hidden_states + self.use_cache = use_cache + self.output_attentions = output_attentions + self.all_hidden_states_add = all_hidden_states_add + self.present_key_value_states_add = present_key_value_states_add + self.all_attentions_add = all_attentions_add + self.all_cross_attentions_add = all_cross_attentions_add + + def forward(self, inputs): + self.batch_id += 1 + # print(f"micro BS: {inputs[0].shape[0]}") + inputs = pipe_decode_all(inputs, inputs[0].shape[0], inputs[0].device) + ( + hidden_states, + attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + ) = inputs + idx = self.idx + + # self.past_key_value = recursive_to(inputs[0].device, self.past_key_value) + self.layer_head_mask = recursive_to(inputs[0].device, self.layer_head_mask) + self.encoder_layer_head_mask = recursive_to(inputs[0].device, self.encoder_layer_head_mask) + + # crazy restore: XXX: fix hardcoded numbers - self.n_layers for number of blocks - 2+2 should always be there + if past_key_values_p1 is not None and past_key_values_p2 is not None: + past_key_value = past_key_values_p1.chunk(self.n_layers, 1)[self.idx].chunk( + 2, 1 + ) + past_key_values_p2.chunk(self.n_layers, 1)[self.idx].chunk(2, 1) + else: + past_key_value = None + + # if self.past_key_value is not None: + # past_key_value = tuple(self.past_key_value[i][self.batch_id] for i in self.past_key_value) + # else: + # past_key_value=None + + # # restore None's if any + # position_bias = None if len(position_bias.shape) == 1 else position_bias + # encoder_hidden_states = None if len(encoder_hidden_states.shape) == 1 else encoder_hidden_states + # encoder_attention_mask = None if len(encoder_attention_mask.shape) == 1 else encoder_attention_mask + # encoder_decoder_position_bias = None if len(encoder_decoder_position_bias.shape) == 1 else encoder_decoder_position_bias + + # all_hidden_states = () if torch.is_tensor(all_hidden_states) and len(all_hidden_states.shape) == 1 else all_hidden_states + # present_key_value_states = () if torch.is_tensor(present_key_value_states) and len(present_key_value_states.shape) == 1 else present_key_value_states + # all_attentions = () if len(all_attentions.shape) == 1 else all_attentions + # all_cross_attentions = () if len(all_cross_attentions.shape) == 1 else all_cross_attentions + + if self.output_hidden_states: + self.all_hidden_states_add(hidden_states) + + layer_outputs = self.layer_module( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=self.layer_head_mask, + 
encoder_layer_head_mask=self.encoder_layer_head_mask, + past_key_value=past_key_value, + use_cache=self.use_cache, + output_attentions=self.output_attentions, + ) + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention weights), + # (self-attention position bias), (cross-attention weights), (cross-attention position bias) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3] + # append next layer key value states + if self.use_cache: + # print(idx, self.batch_id) + # present_key_values_p1 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) + # present_key_values_p2 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) + present_key_values_p1 = torch.cat(present_key_value_state[0:2], 1) + present_key_values_p2 = torch.cat(present_key_value_state[2:4], 1) + self.present_key_value_states_add(present_key_value_state, idx, self.batch_id) + + if self.output_attentions: + self.all_attentions_add(layer_outputs[3]) + if self.is_decoder: + self.all_cross_attentions_add(layer_outputs[5]) + + # tnone = torch.tensor([float('nan')]*out_shape) + # tnone = torch.tensor([-100]*out_shape).to(hidden_states.device) + # position_bias = tnone if position_bias is None else position_bias + # encoder_hidden_states = tnone if encoder_hidden_states is None else encoder_hidden_states + # encoder_attention_mask = tnone if encoder_attention_mask is None else encoder_attention_mask + # encoder_decoder_position_bias = tnone if encoder_decoder_position_bias is None else encoder_decoder_position_bias + # all_hidden_states = tnone if all_hidden_states is None else all_hidden_states + # present_key_value_states = tnone if present_key_value_states is None else present_key_value_states + # all_attentions = tnone if all_attentions is None else all_attentions + # all_cross_attentions = tnone if all_cross_attentions is None else all_cross_attentions + + outputs = ( + hidden_states, + attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + ) + outputs = pipe_encode_all(outputs, hidden_states.shape[0], hidden_states.device) + return outputs + + class T5Stack(T5PreTrainedModel): def __init__(self, config, embed_tokens=None): super().__init__(config) @@ -786,46 +944,25 @@ def __init__(self, config, embed_tokens=None): self.block = nn.ModuleList( [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] ) + + # self.is_pipeline = False + self.is_pipeline = True + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) self.init_weights() - # Model parallel self.model_parallel = False - self.device_map = None - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else 
device_map - ) - assert_device_map(self.device_map, len(self.block)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - # Load onto devices - for k, v in self.device_map.items(): - for layer in v: - cuda_device = "cuda:" + str(k) - self.block[layer] = self.block[layer].to(cuda_device) - - # Set embed_tokens to first layer - self.embed_tokens = self.embed_tokens.to(self.first_device) - # Set final layer norm to last device - self.final_layer_norm = self.final_layer_norm.to(self.last_device) - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def deparallelize(self): - self.model_parallel = False + self.pipeline_chunks = 0 self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - for i in range(len(self.block)): - self.block[i] = self.block[i].to("cpu") - self.embed_tokens = self.embed_tokens.to("cpu") - self.final_layer_norm = self.final_layer_norm.to("cpu") - torch.cuda.empty_cache() + self.pipeline_is_enabled = False + self.pipeline_batch_size = None + + def pipeline_params(self, chunks, device_map): + self.pipeline_chunks = chunks + self.device_map = device_map + self.pipeline_is_enabled = True def get_input_embeddings(self): return self.embed_tokens @@ -848,6 +985,9 @@ def forward( output_hidden_states=None, return_dict=None, ): + + # print(f"mini BS: {input_ids.shape[0]}") + # Model parallel if self.model_parallel: torch.cuda.set_device(self.first_device) @@ -896,8 +1036,8 @@ def forward( ) # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # if past_key_values is None: + # past_key_values = [None] * len(self.block) # ourselves in which case we just need to make it broadcastable to all heads. 
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) @@ -919,71 +1059,160 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): - layer_head_mask = head_mask[i] - encoder_layer_head_mask = encoder_head_mask[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if position_bias is not None: - position_bias = position_bias.to(hidden_states.device) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) - if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) - if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) - if layer_head_mask is not None: - layer_head_mask = layer_head_mask.to(hidden_states.device) - if encoder_layer_head_mask is not None: - encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - encoder_layer_head_mask=encoder_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, + if self.pipeline_is_enabled: + # handle batches (usually last) that are shorter than pipeline_chunks + if batch_size < self.pipeline_chunks: + # # XXX: Is it always last? 
so that we don't override user's chunks setting unless it's the last batch + self.pipeline_chunks = 1 + else: + # non-pipeline run is the same as chunks=1 batch size-wise + self.pipeline_chunks = 1 + + # PP + + n_layers = len(self.block) + # n_chunks = 4 + + def all_hidden_states_add(x): + nonlocal all_hidden_states + all_hidden_states += (x,) + + # handle batches (usually last) that are can't be equally divided by pipeline_chunks + + present_key_value_states = [[0 for x in range(self.pipeline_chunks)] for y in range(n_layers)] + + def present_key_value_states_add(x, block_id, micro_batch_id): + nonlocal present_key_value_states + # present_key_value_states += (x,) + present_key_value_states[block_id][micro_batch_id] = x + # present_key_value_states += (x,) + # print(x.shape for x in present_key_value_states) + # print(present_key_value_states) + + def all_attentions_add(x): + nonlocal all_attentions + all_attentions += (x,) + + def all_cross_attentions_add(x): + nonlocal all_cross_attentions + all_cross_attentions += (x,) + + # crazy flattening of 2 level tuples so that the batch dimension is first to be spliced upon and then restored on the other side + if past_key_values is not None: + x1 = tuple(past_key_values[i][j].to(input_ids.device) for i in range(len(past_key_values)) for j in [0, 1]) + # for i in x1: print(i.shape) + x2 = tuple(past_key_values[i][j].to(input_ids.device) for i in range(len(past_key_values)) for j in [2, 3]) + # for i in x2: print(i.shape) + past_key_values_p1 = torch.cat(x1, 1) + past_key_values_p2 = torch.cat(x2, 1) + # input = torch.cat(tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in range(len(past_key_values[i]))), 1) + else: + past_key_values_p1 = None + past_key_values_p2 = None + + # batch_size=2, blocks=self.n_layers, fixed=2 (2+2 keys) + present_key_values_p1 = torch.empty(batch_size, n_layers * 2).to(input_ids.device) + present_key_values_p2 = torch.empty(batch_size, n_layers * 2).to(input_ids.device) + + # rewrite the model after pre-trained weights were loaded + layers = [ + T5StackPipeSegment( + idx, + n_layers, + layer_module, + self.is_decoder, + head_mask[idx], + encoder_head_mask[idx], + output_hidden_states, + use_cache, + output_attentions, + all_hidden_states_add, + present_key_value_states_add, + all_attentions_add, + all_cross_attentions_add, ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention weights), - # (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != 
self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) + for idx, layer_module in enumerate(self.block) + ] + # block_sequential = nn.Sequential(*layers) + + # for now don't enable the pipe + if self.pipeline_is_enabled: + + # print("using partitioning: ", dict(zip(devices, layer_splits))) + for device_id, layer_partition in self.device_map.items(): + for layer_id in layer_partition: + # print(f"{layer_id} => {device_id}") + layers[layer_id].to(device_id) + + block_sequential = nn.Sequential(*layers) + block_pipe = Pipe(block_sequential, chunks=self.pipeline_chunks, checkpoint="never") + else: + block_sequential = nn.Sequential(*layers) + + inputs = ( + hidden_states, + extended_attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + ) + # , all_hidden_states, present_key_value_states, all_attentions, all_cross_attentions) + inputs = pipe_encode_all(inputs, batch_size, input_ids.device) + + if self.pipeline_is_enabled: + outputs = block_pipe(inputs) + outputs = outputs.local_value() + outputs = recursive_to(input_ids.device, outputs) + else: + outputs = block_sequential(inputs) + + outputs = pipe_decode_all(outputs, batch_size, input_ids.device) + ( + hidden_states, + attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + ) = outputs hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) + # if present_key_values_p1 is not None: + # x1 = present_key_values_p1.chunk(n_layers, 1) + # finalx1 = tuple(x.chunk(2, 1) for x in x1) + + # #present_key_values_p1 = present_key_values_p1.chunk(n_layers, 1).chunk(2, 1) + # #present_key_values_p2 + + # if self.pipeline_is_enabled and present_key_value_states is not None and present_key_value_states[0][0] != 0: + if present_key_value_states is not None and present_key_value_states[0][0] != 0: + # print() + # reconstruct the flattened tensor to tuple of tuples of tensors + new_x = () + # deal with unpredictable potential last short batch + real_chunks = 0 + for i in range(self.pipeline_chunks): + if not present_key_value_states[0][i] == 0: + real_chunks += 1 + for block in present_key_value_states: + new_y = () + for j in (0, 1, 2, 3): + entries = tuple(block[i][j].to(input_ids.device) for i in range(real_chunks)) + new_y += (torch.cat(entries, 0),) + new_x += (new_y,) + present_key_value_states = new_x + # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1394,6 +1623,149 @@ def __init__(self, config): self.model_parallel = False self.device_map = None + self.pipeline_is_enabled = False + + def pipeline_enable(self, chunks, device_map, mpu=None): + logger.info(f"enabling pipeline with chunks={chunks}") + + # XXX: should be a separate function + import torch + + + try: + # will succeed if rpc has started (deepspeed launcher) + dist_world_size = torch.distributed.get_world_size() + dist_rank = torch.distributed.get_rank() + except: + dist_world_size = 1 + dist_rank = 0 + + #dist_world_size=1 + + if mpu is not None: + log_prefix = f"[p{dist_rank}]" + #logger.warn(f"{log_prefix} got MPU") + logger.warn(f"{log_prefix} DP group { mpu.get_data_parallel_group_device_ids() }") + + this_proc_device_ids = 
mpu.get_model_parallel_group_device_ids() + logger.warn(f"{log_prefix} PP group {this_proc_device_ids } [MPU]") + + # XXX: automate this: + # We must be getting the right groups already from get_model_parallel_group_device_ids() + # If we don't then deepspeed isn't getting the right groups - and it'd break + if dist_world_size == 4: + if dist_rank == 0: + this_proc_device_ids = [0, 1] + else: + this_proc_device_ids = [2, 3] + elif dist_world_size == 2: + if dist_rank == 0: + this_proc_device_ids = [0] + else: + this_proc_device_ids = [1] + + else: + log_prefix = f"[p0]" + this_proc_device_ids = list(range(torch.cuda.device_count())) + + logger.warn(f"{log_prefix} PP group {this_proc_device_ids }") + #logger.warn(f"{log_prefix} uses MP/PP device ids: {this_proc_device_ids}") + + n_gpus = len(this_proc_device_ids) + # XXX: restore this later + # if n_gpus < 2: + # raise ValueError("Need at least 2 gpus to use the pipeline") + + if device_map is not None: + logger.info(f"using user-provided device_map") + else: + + def make_device_map(n_gpus, n_layers): + print(f"making default device map: n_gpus={n_gpus}, n_layers={n_layers}") + devices = list(range(n_gpus)) + layer_ids = list(range(n_layers)) + # XXX: later will have a map - for now just roughly split + # XXX: probably should balance more so that the 0th gpu has the least number of layers, rather than the last one, because 0th gpu is already very busy + layer_splits = [ + layer_ids[i * n_layers // n_gpus : (i + 1) * n_layers // n_gpus] for i in range(n_gpus) + ] + return dict(zip(devices, layer_splits)) + + # XXX: for now assume encode/decoder symmetry - later fix to build each one separately + n_layers = len(self.encoder.block) + device_map = make_device_map(n_gpus, n_layers) + + # 2D parallel - i.e. deepspeed + pp + if mpu is not None: + # we need to assign the correct set of IDs for this process - that we get from MPU + # in case of 2D the user describes the device map only for the first group of DP + # and we need to re-assign the ids for the rest of the groups, so say a user passes a device map: + # 0:0-7, 1:7-14 + # for process rank 0 it remains that, but for process rank 1 it should become: + # 2:0-7, 3:7-14 + # and so on. 
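+            # i.e. the layer partitions themselves stay unchanged, only the device keys are
+            # re-assigned to the MP device ids that belong to this process.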
+ # MPU gives us the correct local MP group (this_proc_device_ids) + remapped_device_map = {} + for i, id in enumerate(device_map.keys()): + remapped_device_map[this_proc_device_ids[i]] = device_map[id] + self.device_map = remapped_device_map + else: + self.device_map = device_map + + self.pipeline_is_enabled = True + logger.warn(f"{log_prefix} PP partitions: {self.device_map}") + + # XXX: validate chunks is a good arg + + self.encoder.pipeline_params(chunks=chunks, device_map=self.device_map) + self.decoder.pipeline_params(chunks=chunks, device_map=self.device_map) + + # XXX for now hardcoded the RPC setup here - but it should happen in the trainer instead + import os + + import torch + from torch.distributed import rpc + + # dynamically check if rpc has been initialized already - i.e in case we have deepspeed as the launcher of 2D + # XXX: this needs to be parameterized/cleaned up + # try: + # # will succeed if rpc has started (deepspeed launcher) + # torch.distributed.get_world_size() + # except: + if 1: + os.environ.update({"MASTER_ADDR": "localhost"}) + os.environ.update({"MASTER_PORT": "10639"}) + rpc.init_rpc( + #"worker", + f"worker{dist_rank}", + #rank=0, + rank=dist_rank, + world_size=dist_world_size, + ) + num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 2 + device = torch.device("cuda") + + def pipeline_finalize(self): + # XXX: should reset the max counter + # reset_peak_stats() + import json + + import torch + + n_gpus = torch.cuda.device_count() + mem_map = {} + for id in range(n_gpus): + with torch.cuda.device(id): + # XXX: this doesn't seem to report the right thing - getting much lower numbers + mem_map[id] = torch.cuda.max_memory_allocated() >> 20 + + logger.info(f"peak memory usage per device in MBs:\n{json.dumps(mem_map, sort_keys=True, indent=4)}") + # reset for the next train/eval/predict stage + # XXX: probably should do in the trainer? + torch.cuda.reset_peak_memory_stats() + + # XXX: would be great to add gpu utilization stats as well + @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): self.device_map = ( diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 009c9bff10ba..d866ff7e7118 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -284,6 +284,7 @@ def __init__( else: self.is_model_parallel = False + # Setup Sharded DDP training self.sharded_ddp = None if len(args.sharded_ddp) > 0: @@ -318,6 +319,64 @@ def __init__( ): self.place_model_on_device = False + # XXX: this is probably wrong - as it won't fit on device normally + if len(self.args.pipeline): + model = model.to(args.device) + + self.mpu = None + # XXX: for now hack over naive MP to have the same behavior + if len(self.args.pipeline): + # using range() syntax for upper boundary (i.e. 
not inclusive) + # --pipeline "chunks=5; device_map=0:0-10,1:10-20" + self.is_model_parallel = True + + # arg parser + pp_args = {} + pp_args_str = self.args.pipeline.split() + if len(pp_args_str): + for x in pp_args_str: + k,v = x.split("=") + pp_args[k] = v + + if "chunks" in pp_args: + pp_args["chunks"] = int(pp_args["chunks"]) + else: + # XXX: probably can try some smart dynamic default based on batch_size + pp_args["chunks"] = 2 + + if "device_map" in pp_args: + device_map_range_str = pp_args["device_map"].split(",") + device_map = {} + for x in device_map_range_str: + device_id, layers = x.split(":") + device_map[int(device_id)] = list(range(*map(int, layers.split("-")))) + pp_args["device_map"] = device_map + else: + pp_args["device_map"] = None + + if "n_gpus_per_mp" in pp_args: + pp_args["n_gpus_per_mp"] = int(pp_args["n_gpus_per_mp"]) + else: + # XXX: can try some smart dynamic default here based on total_n_gpus, + # if it's not 2D all gpus will be used + # if 2D half gpus should be a good default + pp_args["n_gpus_per_mp"] = 2 + + # 2D Parallel + if self.args.deepspeed: + from .integrations import MPU + self.mpu = MPU() + #n_gpus = torch.distributed.get_world_size() + # XXX: hardcoded for 2 gpus for PP/MP - needs to be configurable + #n_gpus_per_mp = n_gpus/2 + # at the moment experimenting with just 4 gpus - hence 2 gpus for MP|PP, 2 for DP + self.mpu.initialize_model_parallel(pp_args["n_gpus_per_mp"]) + + model.pipeline_enable(chunks=pp_args["chunks"], device_map=pp_args["device_map"], mpu=self.mpu) + + + + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset @@ -888,7 +947,7 @@ def train( delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE if self.args.deepspeed: - model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps) + model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps, mpu=self.mpu) self.model = model.module self.model_wrapped = model # will get further wrapped in DDP self.deepspeed = model # DeepSpeedEngine object @@ -1114,6 +1173,9 @@ def train( # Clean the state at the end of training delattr(self, "_past") + if len(self.args.pipeline): + self.model.pipeline_finalize() + logger.info("\n\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\n\n")
 
         if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
             logger.info(
@@ -1823,6 +1885,9 @@ def prediction_loop(
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
 
+        if len(self.args.pipeline):
+            self.model.pipeline_finalize()
+
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
     def _gather_and_numpify(self, tensors, name):
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 90c04f89daa3..95d282868956 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -410,6 +410,13 @@ class TrainingArguments:
     )
     debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
 
+    pipeline: str = field(
+        default="",
+        metadata={
+            "help": "Whether to enable Pipeline Parallelism; the value holds the pipeline params, e.g. 'chunks=5 device_map=0:0-10,1:10-20 n_gpus_per_mp=2'"
+        },
+    )
+
     dataloader_drop_last: bool = field(
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
    )
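As a reference for the `--pipeline` argument added above, the following is a minimal, standalone sketch of how such a value decomposes into `chunks`, `device_map` and `n_gpus_per_mp`, mirroring the parsing done in the trainer hunk. The `parse_pipeline_arg` name and the example string are illustrative only and are not part of the patch.

# Illustrative sketch (not part of the patch) of how a --pipeline value such as
# "chunks=3 device_map=0:0-6,1:6-12" is decomposed, mirroring the Trainer hunk above.
def parse_pipeline_arg(pipeline: str) -> dict:
    pp_args = {"chunks": 2, "device_map": None, "n_gpus_per_mp": 2}  # defaults used in the diff
    for item in pipeline.split():
        key, value = item.split("=")
        pp_args[key] = value
    pp_args["chunks"] = int(pp_args["chunks"])
    pp_args["n_gpus_per_mp"] = int(pp_args["n_gpus_per_mp"])
    if isinstance(pp_args["device_map"], str):
        device_map = {}
        for entry in pp_args["device_map"].split(","):
            device_id, layers = entry.split(":")
            start, end = map(int, layers.split("-"))
            device_map[int(device_id)] = list(range(start, end))  # upper bound is exclusive
        pp_args["device_map"] = device_map
    return pp_args

# parse_pipeline_arg("chunks=3 device_map=0:0-6,1:6-12") ->
# {"chunks": 3, "device_map": {0: [0, 1, 2, 3, 4, 5], 1: [6, 7, 8, 9, 10, 11]}, "n_gpus_per_mp": 2}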