From 3018030695db7c2069da2b52b416e0898cdf6053 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 2 Jan 2021 11:48:38 -0800 Subject: [PATCH 1/6] bart goes parallel --- src/transformers/models/bart/modeling_bart.py | 267 +++++++++++++++++- src/transformers/models/t5/modeling_t5.py | 10 +- src/transformers/utils/logging.py | 2 + .../utils/model_parallel_utils.py | 262 +++++++++++++++++ tests/test_modeling_bart.py | 5 + tests/test_modeling_common.py | 77 +++-- 6 files changed, 569 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7f4af885d5b5..ca170f3e0c88 100644 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -42,6 +42,14 @@ ) from ...modeling_utils import PreTrainedModel from ...utils import logging +from ...utils.model_parallel_utils import ( + init_device_map, + model_parallel_inputs_to_device, + model_parallel_inputs_to_specific_device, +) + +# log_name_device, +# print_layer_devices, from .configuration_bart import BartConfig @@ -109,13 +117,25 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) +PARALLELIZE_DOCSTRING = r""" +XXX: TODO modeling_t5.py:181 +""" +DEPARALLELIZE_DOCSTRING = r""" +""" + + +# XXX: overcome RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm(...) +# in BartEncoder.forward +# it works fine once on "cuda:0", then after it runs on "cuda:1" of the same code it corrupts `hidden_states` and self.fc1(hidden_states) blows up with the above error +# https://github.com/pytorch/fairseq/issues/2012 native might be slightly slower than apex FusedLayerNorm +# def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True): - try: - from apex.normalization import FusedLayerNorm + # try: + # from apex.normalization import FusedLayerNorm - return FusedLayerNorm(normalized_shape, eps, elementwise_affine) - except ImportError: - pass + # return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + # except ImportError: + # pass return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) @@ -207,6 +227,7 @@ def __init__( def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + # @model_parallel_inputs_to_device def forward( self, hidden_states: torch.Tensor, @@ -219,9 +240,16 @@ def forward( # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder + + # print_layer_devices(self) + # hidden_states, key_value_states, past_key_value, attention_mask, output_attentions = model_parallel_inputs_to_specific_device(self.k_proj.weight.device, hidden_states, key_value_states, past_key_value, attention_mask, output_attentions) + # logger.info(f"MP {self.__class__.__name__} {log_name_device(key_value_states, 'key_value_states')}") + is_cross_attention = key_value_states is not None bsz, tgt_len, embed_dim = hidden_states.size() + # torch.cuda.set_device(hidden_states.device) + # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj @@ -230,6 +258,8 @@ def forward( key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.k_proj, 'self.k_proj')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(key_value_states, 'key_value_states')}") # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) @@ -329,6 +359,17 @@ def __init__(self, config: BartConfig): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = BartLayerNorm(self.embed_dim) + self.model_parallel = False + + def parallelize(self, device): + self.to(device) + self.model_parallel = True + + def deparallelize(self): + self.to("cpu") + self.model_parallel = False + + @model_parallel_inputs_to_device def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): """ Args: @@ -337,20 +378,33 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, out `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function. """ + # print_layer_devices(self) residual = hidden_states if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) + + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.self_attn, 'self.self_attn')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(hidden_states)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(attention_mask, 'attention_mask')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(output_attentions, 'output_attentions')}") hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states + + # torch.cuda.set_device(hidden_states.device) + # print(hidden_states) if not self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) + # print(hidden_states) residual = hidden_states if self.normalize_before: hidden_states = self.final_layer_norm(hidden_states) + + # XXX: this is where apex.normalization.FusedLayerNorm causes explosion, and triggered by one of the above layer norm calls + # RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm(...)` hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) @@ -392,6 +446,17 @@ def __init__(self, config: BartConfig): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = BartLayerNorm(self.embed_dim) + self.model_parallel = False + + def parallelize(self, device): + self.to(device) + self.model_parallel = True + + def deparallelize(self): + self.to("cpu") + self.model_parallel = False + + @model_parallel_inputs_to_device def forward( self, hidden_states: torch.Tensor, @@ -416,6 +481,11 @@ def forward( if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.self_attn, 'self.self_attn')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(hidden_states)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(attention_mask, 'attention_mask')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(output_attentions, 'output_attentions')}") + # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None @@ -438,6 +508,7 @@ def forward( residual = hidden_states if self.normalize_before: hidden_states = self.encoder_attn_layer_norm(hidden_states) + # print_layer_devices(self) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None @@ -517,6 +588,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + self.model_parallel = False + self.first_device = "cpu" + self.last_device = "cpu" + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -678,6 +753,41 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.init_weights() + self.model_parallel = False + self.first_device = "cpu" + self.last_device = "cpu" + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # device_map = init_device_map(len(self.block), device_map) + self.first_device = f"cuda:{list(device_map.keys())[0]}" + self.last_device = f"cuda:{list(device_map.keys())[-1]}" + # logger.info(f"MP {self.__class__.__name__}. first device: {self.first_device}") + # logger.info(f"MP {self.__class__.__name__}. last device: {self.last_device}") + # Load onto devices + self.embed_tokens.to(self.first_device) + self.embed_positions.to(self.first_device) + for k, v in device_map.items(): + for layer in v: + self.layers[layer].parallelize(f"cuda:{k}") + self.layernorm_embedding.to(self.first_device) + if self.layer_norm is not None: + self.layer_norm.to(self.last_device) # XXX: first? + + self.model_parallel = True + + # @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # def deparallelize(self): + # self.to("cpu") + # self.first_device = "cpu" + # self.last_device = "cpu" + # for i in range(len(self.block)): + # self.block[i].deparallelize() + # self.embed_tokens.to("cpu") + # self.final_layer_norm.to("cpu") + # self.model_parallel = False + # torch.cuda.empty_cache() + def forward( self, input_ids=None, @@ -718,6 +828,21 @@ def forward( return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ + + # logger.info(f"MP {self.__class__.__name__}") + + if self.model_parallel: + ( + input_ids, + attention_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + ) = model_parallel_inputs_to_specific_device( + self.first_device, input_ids, attention_mask, inputs_embeds, output_attentions, output_hidden_states + ) + # torch.cuda.set_device(self.first_device) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -735,11 +860,20 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(input_ids)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.embed_tokens, 'embed_tokens')}") + if inputs_embeds is None: + if self.model_parallel: + self.embed_tokens.to(input_ids.device) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_shape) + # logger.info(f"MP {self.__class__.__name__} {log_name_device(inputs_embeds)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(embed_pos)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(attention_mask)}") + hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -751,7 +885,10 @@ def forward( encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + # for encoder_layer in self.layers: + for i, encoder_layer in enumerate(self.layers): + logger.info(f"MP {self.__class__.__name__} layer {i}") + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -817,6 +954,31 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.init_weights() + self.model_parallel = False + self.first_device = "cpu" + self.last_device = "cpu" + + # xxx: this is the same as the BartEncoder.parallelize + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # device_map = init_device_map(len(self.block), device_map) + self.first_device = f"cuda:{list(device_map.keys())[0]}" + self.last_device = f"cuda:{list(device_map.keys())[-1]}" + # logger.info(f"MP {self.__class__.__name__}. first device: {self.first_device}") + # logger.info(f"MP {self.__class__.__name__}. last device: {self.last_device}") + + # Load onto devices + self.embed_tokens.to(self.first_device) + self.embed_positions.to(self.first_device) + for k, v in device_map.items(): + for layer in v: + self.layers[layer].parallelize(f"cuda:{k}") + self.layernorm_embedding.to(self.first_device) + if self.layer_norm is not None: + self.layer_norm.to(self.last_device) # XXX: first? + + self.model_parallel = True + def forward( self, input_ids=None, @@ -880,6 +1042,34 @@ def forward( return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ + + if self.model_parallel: + ( + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + ) = model_parallel_inputs_to_specific_device( + self.first_device, + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + ) + + # getting RuntimeError: CUDA error: an illegal memory access was encountered w/o the next call + # torch.cuda.set_device(self.first_device) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -902,6 +1092,8 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: + if self.model_parallel: + self.embed_tokens.to(input_ids.device) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale # create causal mask @@ -968,6 +1160,7 @@ def forward( all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): + logger.info(f"MP {self.__class__.__name__} layer {idx}") # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1033,6 +1226,29 @@ def __init__(self, config: BartConfig): self.init_weights() + self.model_parallel = False + self.main_device = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + device_map = init_device_map(len(self.encoder.layers), len(self.decoder.layers), device_map) + encoder_device_map = device_map["encoder"] + decoder_device_map = device_map["decoder"] + + self.encoder.parallelize(encoder_device_map) + self.decoder.parallelize(decoder_device_map) + self.main_device = self.encoder.first_device + self.shared.to(self.main_device) + self.model_parallel = True + + # @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # def deparallelize(self): + # self.encoder.deparallelize() + # self.decoder.deparallelize() + # self.lm_head.to("cpu") + # self.model_parallel = False + # torch.cuda.empty_cache() + def get_input_embeddings(self): return self.shared @@ -1083,7 +1299,7 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + # logger.info(f"MP {self.__class__.__name__}. encoder") if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, @@ -1101,6 +1317,7 @@ def forward( attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) + # logger.info(f"MP {self.__class__.__name__}. decoder") # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -1118,6 +1335,11 @@ def forward( if not return_dict: return decoder_outputs + encoder_outputs + if self.model_parallel: + encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device( + self.main_device, encoder_outputs, decoder_outputs + ) + return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, @@ -1150,6 +1372,31 @@ def __init__(self, config: BartConfig): self.init_weights() + self.model_parallel = False + self.main_device = None + self.last_device = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # device_map = init_device_map(len(self.encoder.block), device_map) + self.model.parallelize(device_map) + self.last_device = self.model.decoder.last_device + self.main_device = self.model.main_device + # logger.info(f"MP {self.__class__.__name__} MAIN DEVICE {self.main_device}") + self.final_logits_bias = self.final_logits_bias.to(self.main_device) + self.lm_head.to(self.main_device) + # logger.info(f"MP {self.__class__.__name__}. last device: {self.last_device}") + self.model_parallel = True + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.lm_head, 'self.lm_head')}") + + # @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # def deparallelize(self): + # self.encoder.deparallelize() + # self.decoder.deparallelize() + # self.lm_head.to("cpu") + # self.model_parallel = False + # torch.cuda.empty_cache() + def get_encoder(self): return self.model.get_encoder() @@ -1242,6 +1489,12 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) + + if self.model_parallel: + self.lm_head.to(self.main_device) + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.lm_head, 'self.lm_head')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(outputs[0], 'outputs[0]')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.final_logits_bias, 'self.final_logits_bias')}") lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias masked_lm_loss = None diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0ce2be3c62ac..92c6f23943e0 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -794,6 +794,7 @@ def parallelize(self, device_map=None): self.block[layer] = self.block[layer].to(cuda_device) # Set embed_tokens to first layer + # XXX: same gets set in forward 2nd time self.embed_tokens = self.embed_tokens.to(self.first_device) # Set final layer norm to last device self.final_layer_norm = self.final_layer_norm.to(self.last_device) @@ -951,9 +952,12 @@ def forward( # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) + devices = list(self.device_map.keys()) + for e, d in enumerate(devices): + if i == self.device_map[d][-1] and f"cuda:{d}" != self.last_device: + hidden_states = hidden_states.to( + f"cuda:{devices[e+1]}" + ) # again assumption that devices are not only ordered but also have a stride of 1 hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index ad514f707a0a..9ac852a7e83d 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -16,6 +16,7 @@ import logging import os +import sys import threading from logging import CRITICAL # NOQA from logging import DEBUG # NOQA @@ -78,6 +79,7 @@ def _configure_library_root_logger() -> None: # This library has already configured the library root logger. return _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush # Apply our default configuration to the library root logger. library_root_logger = _get_library_root_logger() diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py index 3a145df9868b..32910f7edd4c 100644 --- a/src/transformers/utils/model_parallel_utils.py +++ b/src/transformers/utils/model_parallel_utils.py @@ -13,9 +13,271 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from math import ceil +import torch + +def make_default_sub_device_map(n_layers): + """Returns a dictionary of layers distributed evenly across all devices.""" + + # XXX: in the future we can implement a smarter allocation based on the device memory, since: + # 1. cards can be of different memory-size (uncommon, but this developer has this setup) + # 2. also the first device of encoder is used as the main device which will have all the non-layer specific params on it and thus will take more memory, so ideally the default should put less layers on that device + # but, of course, users can customize it to their liking in their code. + # Except it is not possible to customize the map in Trainer-based scripts, like `finetune_trainer.py`, where the user can only switch --model_parallel flag on and no way to set the map. + n_gpus = torch.cuda.device_count() + + # XXX: this function splits the layers evenly across all devices, so that the end result is that each encoder and decoder share devices, as compared to creating maps where either of the two completely takes over one of the devices - need to measure which approach is more efficient - i.e. minimizes inter-device copying. + layers = list(range(n_layers)) + n_blocks = int(ceil(n_layers / n_gpus)) + layers_list = list(layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)) + + return dict(zip(range(n_gpus), layers_list)) + + +def make_default_device_map(encoder_n_layers, decoder_n_layers): + return { + "encoder": make_default_sub_device_map(encoder_n_layers), + "decoder": make_default_sub_device_map(decoder_n_layers), + } + + +def validate_sub_device_map(n_layers, device_map, name): + where = f"device_map['{name}']" + possible_sub_device_map = make_default_sub_device_map(n_layers) + error_msg = f"here is a possible entry for {where}:\n{possible_sub_device_map}" + + # general format + gpu_ids = device_map.keys() + assert all(isinstance(x, int) for x in gpu_ids), ( + f"{where}: All keys much be integers, corresponding to available gpu IDS)\n" + error_msg + ) + layer_ids = [i for v in device_map.values() for i in v] + assert all(isinstance(x, int) for x in layer_ids), ( + f"{where}: Values must contain only integers, corresponding to layer numbers\n" + error_msg + ) + + # reality check + valid_gpu_ids = list(range(torch.cuda.device_count())) + wrong_gpu_ids = [x for x in gpu_ids if x not in valid_gpu_ids] + assert not len(wrong_gpu_ids), ( + f"All keys must correspond to available gpus IDs, but got: {wrong_gpu_ids}\n" + error_msg + ) + + duplicate_layer_ids = [i for i in set(layer_ids) if layer_ids.count(i) > 1] + assert not len(duplicate_layer_ids), ( + f"{where}: duplicate layer numbers detected: {duplicate_layer_ids}\n" + "Each layer number must be specified only once, remove duplicates" + ) + + valid_layer_ids = list(range(0, n_layers)) + missing_layer_ids = [i for i in valid_layer_ids if i not in layer_ids] + assert not len(missing_layer_ids), ( + f"{where}: missing layer numbers detected: {missing_layer_ids}\n" "Add missing layers to the device map." + ) + extra_layer_ids = [i for i in layer_ids if i not in valid_layer_ids] + assert not len(extra_layer_ids), ( + f"{where}: non-existing layer numbers detected: {extra_layer_ids}\n" + f"This {name} has only {n_layers} layers.\n" + "Remove extraneous layers from the device map.\n" + ) + + +def validate_device_map(encoder_n_layers, decoder_n_layers, device_map): + possible_device_map = make_default_device_map(encoder_n_layers, decoder_n_layers) + error_msg = f"invalid device_map format detected, here is a possible device map {possible_device_map}" + + assert "encoder" in device_map and "decoder" in device_map, error_msg + encoder_device_map = device_map["encoder"] + decoder_device_map = device_map["decoder"] + + assert isinstance(encoder_device_map, dict) and isinstance(decoder_device_map, dict), error_msg + + validate_sub_device_map(encoder_n_layers, encoder_device_map, "encoder") + validate_sub_device_map(decoder_n_layers, decoder_device_map, "decoder") + + +def init_device_map(encoder_n_layers, decoder_n_layers, device_map=None): + """ + - creates a device_map if none was passed + - validates that map is correct + + Args: + encoder_n_layers - number of encoder layers to remap + decoder_n_layers - number of decoder layers to remap + device_map - use this user-supplied map + """ + if device_map is None: + device_map = make_default_device_map(encoder_n_layers, decoder_n_layers) + validate_device_map(encoder_n_layers, decoder_n_layers, device_map) + return device_map + + +def log_name_device(var, fallbackname=None): # search from the outmost frame inwards + """ + This helper is useful for debug tracing of devices of variables, e.g.: + logger.info(f"MP {self.__class__.__name__} {log_name_device(attention_mask)}") + if it can't deduce the variable name (or finds wrong name), pass the name explicitly, e.g.: + logger.info(f"MP {self.__class__.__name__} {log_name_device(self.lm_head, 'self.lm_head')}") + """ + + if fallbackname is not None: + name = fallbackname + else: + for f in reversed(inspect.stack()): + name = "unknown" + names = [x for x, val in f.frame.f_locals.items() if val is var] + if len(names) > 0: + name = names[0] + break + if var is None: + return f"{name} val=None" + + device = None + try: + device = var.device + except AttributeError: + if hasattr(var, "parameters"): + device = get_layer_device(var) + return f"{name} {device}" + + +def get_layer_device(self): + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + return device + + +# def to_dev(self, input): +# try: +# device = next(self.parameters(recurse=True)).device +# except StopIteration: +# device = None + +# if device is None: +# raise ValueError(f"Can't find any params for {self.__class__}") +# print(f"manual switch to {device}") +# return input.to(device) + + +def recursive_to(device, item): + """ + Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list, + tuple and dict structures. + """ + + if torch.is_tensor(item): + return item.to(device) + + elif isinstance(item, list): + for i, x in enumerate(item): + item[i] = recursive_to(device, x) + return item + + elif isinstance(item, tuple): + return tuple(recursive_to(device, list(item))) + + elif isinstance(item, dict): + for k, v in item.items(): + item[k] = recursive_to(device, v) + return item + + else: + return item + + +def model_parallel_inputs_to_device(func): + """ + This decorator is a noop unless self.model_parallel == True. + + It will try to find at least one parameter to read layer's .device from and then will automatically copy any inputs + to that device before `forward` is called. Use it as: + + @model_parallel_inputs_to_device def forward(self, input1, input2, ...) + + It will do its magical thing only if all params of this layer are on the same device. If it is not the case use + `model_parallel_inputs_to_specific_device` at the top of `forward` + """ + + def _call__mp(self, *input, **kwargs): + + if not hasattr(self, "model_parallel") or not self.model_parallel: + return func(self, *input, **kwargs) + + # get device of any of the param of this layer + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + + # print(f"layer device: {device}") + if device is not None: + # torch.cuda.set_device(device) + # print(f"auto-move inputs to {device}") + + input = recursive_to(device, input) + kwargs = recursive_to(device, kwargs) + + return func(self, *input, **kwargs) + + return _call__mp + + +def model_parallel_inputs_to_specific_device(device, *input): + """ + Similar to the model_parallel_inputs_to_device decorator, but this one is used for situations either when: 1. an + explicit call is desired (similar to `model.to()`) 2. the layer has params on mixed devices and therefore a wrong + device might get picked + + To use: + + @model_parallel_inputs_to_device def forward(self, input1, input2, ...): # get the desired device somewhere, e.g. a + specific param or a module attribute device = self.fc1.device input1, input2 = + model_parallel_inputs_to_specific_device(device, input1, input2) # this is the same as: input1 = input1.to(device) + input2 = input2.to(device) # but it works on variables that contain tensors but don't have `.to()` otherwise + """ + if device is None: + raise ValueError("device cannot be None") + # print(f"move specific inputs to {device}") + input = recursive_to(device, input) + # remove the need for the caller to perform "a, = foo(a)", + # which otherwise will make `a` a tuple when it might not be one + return input[0] if len(input) == 1 else input + + +# def model_parallel_call(self, *input, **kwargs): + +# # get device of any of the param of this layer +# try: +# device = next(self.parameters(recurse=True)).device +# except StopIteration: +# device = None + +# # print(f"layer device: {device}") +# if device is not None: +# # torch.cuda.set_device(device) +# input = recursive_to(device, input) +# kwargs = recursive_to(device, kwargs) + +# return nn.Module.__call__(self, *input, **kwargs) + + +def print_layer_devices(self): + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + print(f"device dump - looked up device {device}") + for n, p in self.named_parameters(): + print(f"{n}: {p.device}") + + +# XXX: still used by t5 + gpt2 so leave here for now +# will be removed once the other functions above have been integrated def assert_device_map(device_map, num_blocks): blocks = list(range(0, num_blocks)) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index f38816e095e2..e254852efc61 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -114,6 +114,9 @@ def __init__( self.bos_token_id = bos_token_id torch.manual_seed(0) + def get_large_model_config(self): + return BartConfig.from_pretrained("facebook/bart-base") + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( 3, @@ -218,10 +221,12 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () + all_parallelizable_model_classes = (BartModel, BartForConditionalGeneration) if is_torch_available() else () is_encoder_decoder = True test_pruning = False test_head_masking = False test_missing_keys = False + test_model_parallel = True def setUp(self): self.model_tester = BartModelTester(self) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2b720566539f..a50f611305b8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import gc import inspect import os.path import random @@ -24,6 +25,7 @@ from transformers import is_torch_available from transformers.file_utils import WEIGHTS_NAME from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device +from transformers.utils.model_parallel_utils import model_parallel_inputs_to_specific_device if is_torch_available(): @@ -1081,15 +1083,15 @@ def test_model_parallelization(self): if not self.test_model_parallel: return - import subprocess - + # a candidate for testing_utils def get_current_gpu_memory_use(): - run_process = subprocess.Popen( - "nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader", shell=True, stdout=subprocess.PIPE - ) + """ returns a list of cuda memory allocations per GPU in MBs""" + + per_device_memory = [] + for id in range(torch.cuda.device_count()): + with torch.cuda.device(id): + per_device_memory.append(torch.cuda.memory_allocated() >> 20) - memory_usage = run_process.stdout.read().decode("utf-8").strip() - per_device_memory = [int(memory) for memory in memory_usage.split("\n")] return per_device_memory # Needs a large model to see the difference. @@ -1098,39 +1100,44 @@ def get_current_gpu_memory_use(): for model_class in self.all_parallelizable_model_classes: torch.cuda.empty_cache() - # Retrieve initial memory usage (should be close to 0) - initial_memory = get_current_gpu_memory_use() + # 1. single gpu memory load + unload + memory measurements + # Retrieve initial memory usage (can easily be ~0.6-1.5GB if cuda-kernels have been preloaded by previous tests) + memory_at_start = get_current_gpu_memory_use() - # Put model on device - model = model_class(config.from_pretrained("gpt2")) + # Put model on device 0 and take a memory snapshot + model = model_class(config) model.to("cuda:0") - - # Retrieve the memory after the model is put on the device memory_after_model_load = get_current_gpu_memory_use() + # The memory use on device 0 should be higher than it was initially. + self.assertGreater(memory_after_model_load[0], memory_at_start[0]) + del model + gc.collect() torch.cuda.empty_cache() - # The memory use on that device should be higher than it was initially. - self.assertGreater(memory_after_model_load[0], initial_memory[0]) + # 2. MP test + # it's essential to re-calibrate the usage before the next stage + memory_at_start = get_current_gpu_memory_use() # Spread model layers over multiple devices - model = model_class(config.from_pretrained("gpt2")) + model = model_class(config) model.parallelize() memory_after_parallelization = get_current_gpu_memory_use() # Assert that the memory use on all devices is higher than it was when loaded only on CPU for n in range(torch.cuda.device_count()): - self.assertGreater(memory_after_parallelization[n], initial_memory[n]) + self.assertGreater(memory_after_parallelization[n], memory_at_start[n]) - # Assert that the memory use of the first device is lower than it was when the entire model was loaded on it + # Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it self.assertLess(memory_after_parallelization[0], memory_after_model_load[0]) - # Assert that the memory use of the second device is higher than it was when the entire model was loaded - # on the other device. + # Assert that the memory use of device 1 is higher than it was when the entire model was loaded + # on device 0 and device 1 wasn't used at all self.assertGreater(memory_after_parallelization[1], memory_after_model_load[1]) del model + gc.collect() torch.cuda.empty_cache() @require_torch_multi_gpu @@ -1143,22 +1150,12 @@ def test_model_parallel_equal_results(self): for model_class in self.all_parallelizable_model_classes: inputs_dict = self._prepare_for_class(inputs_dict, model_class) - def cast_to_device(dictionary, device): - output = {} - for k, v in dictionary.items(): - if isinstance(v, torch.Tensor): - output[k] = v.to(device) - else: - output[k] = v - - return output - model = model_class(config) - output = model(**cast_to_device(inputs_dict, "cpu")) - model.parallelize() + output = model(**model_parallel_inputs_to_specific_device("cpu", inputs_dict)) - parallel_output = model(**cast_to_device(inputs_dict, "cuda:0")) + model.parallelize() + parallel_output = model(**model_parallel_inputs_to_specific_device("cuda:0", inputs_dict)) for value, parallel_value in zip(output, parallel_output): if isinstance(value, torch.Tensor): @@ -1177,23 +1174,15 @@ def test_model_parallel_beam_search(self): ) config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if "decoder_input_ids" in inputs_dict: + inputs_dict.pop("decoder_input_ids") # Bart* breaks if it's passed for model_class in all_generative_and_parallelizable_model_classes: inputs_dict = self._prepare_for_class(inputs_dict, model_class) model = model_class(config) - def cast_to_device(dictionary, device): - output = {} - for k, v in dictionary.items(): - if isinstance(v, torch.Tensor): - output[k] = v.to(device) - else: - output[k] = v - - return output - model.parallelize() - model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2) + model.generate(**model_parallel_inputs_to_specific_device("cuda:0", inputs_dict), num_beams=2) global_rng = random.Random() From 881ab15eb6f102ed16b5944a89c9e448567ed8d5 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 2 Jan 2021 12:21:03 -0800 Subject: [PATCH 2/6] logging auto-flush to its own PR --- src/transformers/utils/logging.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index 9ac852a7e83d..ad514f707a0a 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -16,7 +16,6 @@ import logging import os -import sys import threading from logging import CRITICAL # NOQA from logging import DEBUG # NOQA @@ -79,7 +78,6 @@ def _configure_library_root_logger() -> None: # This library has already configured the library root logger. return _default_handler = logging.StreamHandler() # Set sys.stderr as stream. - _default_handler.flush = sys.stderr.flush # Apply our default configuration to the library root logger. library_root_logger = _get_library_root_logger() From 820e6917865efccab6c4533c2957d8862ebcce22 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 2 Jan 2021 12:27:17 -0800 Subject: [PATCH 3/6] belongs to another PR --- src/transformers/models/t5/modeling_t5.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 92c6f23943e0..0ce2be3c62ac 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -794,7 +794,6 @@ def parallelize(self, device_map=None): self.block[layer] = self.block[layer].to(cuda_device) # Set embed_tokens to first layer - # XXX: same gets set in forward 2nd time self.embed_tokens = self.embed_tokens.to(self.first_device) # Set final layer norm to last device self.final_layer_norm = self.final_layer_norm.to(self.last_device) @@ -952,12 +951,9 @@ def forward( # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: - devices = list(self.device_map.keys()) - for e, d in enumerate(devices): - if i == self.device_map[d][-1] and f"cuda:{d}" != self.last_device: - hidden_states = hidden_states.to( - f"cuda:{devices[e+1]}" - ) # again assumption that devices are not only ordered but also have a stride of 1 + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) From 57606f94470f968629303bbec0402551021499af Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 2 Jan 2021 12:39:50 -0800 Subject: [PATCH 4/6] is_torch_available --- tests/test_modeling_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a50f611305b8..8b6388d2199e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -25,7 +25,10 @@ from transformers import is_torch_available from transformers.file_utils import WEIGHTS_NAME from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device -from transformers.utils.model_parallel_utils import model_parallel_inputs_to_specific_device + + +if is_torch_available(): + from transformers.utils.model_parallel_utils import model_parallel_inputs_to_specific_device if is_torch_available(): From d5d5fad88c5c87ba300276906094704452ddecdc Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 4 Jan 2021 11:40:00 -0800 Subject: [PATCH 5/6] style --- src/transformers/models/bart/modeling_bart.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index aa256304cf38..5ff09dd0eefe 100644 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -123,6 +123,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] DEPARALLELIZE_DOCSTRING = r""" """ + class BartLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting From fe21c43745fcf3f7958c17c2ac461bd784094205 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 10:30:32 -0800 Subject: [PATCH 6/6] make --model_parallel work with Bart --- src/transformers/models/bart/modeling_bart.py | 9 +++++---- src/transformers/trainer.py | 8 ++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9ff504f81efd..e9efa63adeb9 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -500,6 +500,7 @@ def forward(self, hidden_states: torch.Tensor): class BartPretrainedModel(PreTrainedModel): config_class = BartConfig base_model_prefix = "model" + is_parallelizable = True def _init_weights(self, module): std = self.config.init_std @@ -715,8 +716,8 @@ def parallelize(self, device_map=None): for layer in v: self.layers[layer].parallelize(f"cuda:{k}") self.layernorm_embedding.to(self.first_device) - if self.layer_norm is not None: - self.layer_norm.to(self.last_device) # XXX: first? + # if self.layer_norm is not None: + # self.layer_norm.to(self.last_device) # XXX: first? self.model_parallel = True @@ -923,8 +924,8 @@ def parallelize(self, device_map=None): for layer in v: self.layers[layer].parallelize(f"cuda:{k}") self.layernorm_embedding.to(self.first_device) - if self.layer_norm is not None: - self.layer_norm.to(self.last_device) # XXX: first? + # if self.layer_norm is not None: + # self.layer_norm.to(self.last_device) # XXX: first? self.model_parallel = True diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index effa50b5a92e..d1a964eaaf45 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -267,6 +267,14 @@ def __init__( ) self.model_init = model_init + if self.args.model_parallel: + if model.is_parallelizable: + model.parallelize() + else: + raise ValueError( + f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used" + ) + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset