diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 3c8b56e69e83..e9efa63adeb9 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -43,6 +43,14 @@ ) from ...modeling_utils import PreTrainedModel from ...utils import logging +from ...utils.model_parallel_utils import ( + init_device_map, + model_parallel_inputs_to_device, + model_parallel_inputs_to_specific_device, +) + +# log_name_device, +# print_layer_devices, from .configuration_bart import BartConfig @@ -102,6 +110,13 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) +PARALLELIZE_DOCSTRING = r""" +XXX: TODO modeling_t5.py:181 +""" +DEPARALLELIZE_DOCSTRING = r""" +""" + + class BartLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -153,6 +168,7 @@ def __init__( def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + # @model_parallel_inputs_to_device def forward( self, hidden_states: torch.Tensor, @@ -165,9 +181,16 @@ def forward( # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder + + # print_layer_devices(self) + # hidden_states, key_value_states, past_key_value, attention_mask, output_attentions = model_parallel_inputs_to_specific_device(self.k_proj.weight.device, hidden_states, key_value_states, past_key_value, attention_mask, output_attentions) + # logger.info(f"MP {self.__class__.__name__} {log_name_device(key_value_states, 'key_value_states')}") + is_cross_attention = key_value_states is not None bsz, tgt_len, embed_dim = hidden_states.size() + # torch.cuda.set_device(hidden_states.device) + # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj @@ -176,6 +199,8 @@ def forward( key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.k_proj, 'self.k_proj')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(key_value_states, 'key_value_states')}") # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) @@ -274,6 +299,17 @@ def __init__(self, config: BartConfig): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.model_parallel = False + + def parallelize(self, device): + self.to(device) + self.model_parallel = True + + def deparallelize(self): + self.to("cpu") + self.model_parallel = False + + @model_parallel_inputs_to_device def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): """ Args: @@ -284,6 +320,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, out Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. """ + # print_layer_devices(self) residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions @@ -339,6 +376,17 @@ def __init__(self, config: BartConfig): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.model_parallel = False + + def parallelize(self, device): + self.to(device) + self.model_parallel = True + + def deparallelize(self): + self.to("cpu") + self.model_parallel = False + + @model_parallel_inputs_to_device def forward( self, hidden_states: torch.Tensor, @@ -364,6 +412,11 @@ def forward( """ residual = hidden_states + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.self_attn, 'self.self_attn')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(hidden_states)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(attention_mask, 'attention_mask')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(output_attentions, 'output_attentions')}") + # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None @@ -447,6 +500,7 @@ def forward(self, hidden_states: torch.Tensor): class BartPretrainedModel(PreTrainedModel): config_class = BartConfig base_model_prefix = "model" + is_parallelizable = True def _init_weights(self, module): std = self.config.init_std @@ -459,6 +513,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + self.model_parallel = False + self.first_device = "cpu" + self.last_device = "cpu" + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -640,6 +698,41 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.init_weights() + self.model_parallel = False + self.first_device = "cpu" + self.last_device = "cpu" + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # device_map = init_device_map(len(self.block), device_map) + self.first_device = f"cuda:{list(device_map.keys())[0]}" + self.last_device = f"cuda:{list(device_map.keys())[-1]}" + # logger.info(f"MP {self.__class__.__name__}. first device: {self.first_device}") + # logger.info(f"MP {self.__class__.__name__}. last device: {self.last_device}") + # Load onto devices + self.embed_tokens.to(self.first_device) + self.embed_positions.to(self.first_device) + for k, v in device_map.items(): + for layer in v: + self.layers[layer].parallelize(f"cuda:{k}") + self.layernorm_embedding.to(self.first_device) + # if self.layer_norm is not None: + # self.layer_norm.to(self.last_device) # XXX: first? + + self.model_parallel = True + + # @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # def deparallelize(self): + # self.to("cpu") + # self.first_device = "cpu" + # self.last_device = "cpu" + # for i in range(len(self.block)): + # self.block[i].deparallelize() + # self.embed_tokens.to("cpu") + # self.final_layer_norm.to("cpu") + # self.model_parallel = False + # torch.cuda.empty_cache() + def forward( self, input_ids=None, @@ -680,6 +773,21 @@ def forward( return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ + + # logger.info(f"MP {self.__class__.__name__}") + + if self.model_parallel: + ( + input_ids, + attention_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + ) = model_parallel_inputs_to_specific_device( + self.first_device, input_ids, attention_mask, inputs_embeds, output_attentions, output_hidden_states + ) + # torch.cuda.set_device(self.first_device) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -697,11 +805,20 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(input_ids)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.embed_tokens, 'embed_tokens')}") + if inputs_embeds is None: + if self.model_parallel: + self.embed_tokens.to(input_ids.device) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_shape) + # logger.info(f"MP {self.__class__.__name__} {log_name_device(inputs_embeds)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(embed_pos)}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(attention_mask)}") + hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -713,7 +830,10 @@ def forward( encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + # for encoder_layer in self.layers: + for i, encoder_layer in enumerate(self.layers): + logger.info(f"MP {self.__class__.__name__} layer {i}") + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -784,6 +904,31 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.init_weights() + self.model_parallel = False + self.first_device = "cpu" + self.last_device = "cpu" + + # xxx: this is the same as the BartEncoder.parallelize + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # device_map = init_device_map(len(self.block), device_map) + self.first_device = f"cuda:{list(device_map.keys())[0]}" + self.last_device = f"cuda:{list(device_map.keys())[-1]}" + # logger.info(f"MP {self.__class__.__name__}. first device: {self.first_device}") + # logger.info(f"MP {self.__class__.__name__}. last device: {self.last_device}") + + # Load onto devices + self.embed_tokens.to(self.first_device) + self.embed_positions.to(self.first_device) + for k, v in device_map.items(): + for layer in v: + self.layers[layer].parallelize(f"cuda:{k}") + self.layernorm_embedding.to(self.first_device) + # if self.layer_norm is not None: + # self.layer_norm.to(self.last_device) # XXX: first? + + self.model_parallel = True + def forward( self, input_ids=None, @@ -847,6 +992,34 @@ def forward( return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ + + if self.model_parallel: + ( + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + ) = model_parallel_inputs_to_specific_device( + self.first_device, + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + ) + + # getting RuntimeError: CUDA error: an illegal memory access was encountered w/o the next call + # torch.cuda.set_device(self.first_device) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -869,6 +1042,8 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if inputs_embeds is None: + if self.model_parallel: + self.embed_tokens.to(input_ids.device) inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale # create causal mask @@ -904,6 +1079,7 @@ def forward( all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): + logger.info(f"MP {self.__class__.__name__} layer {idx}") # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -990,6 +1166,29 @@ def __init__(self, config: BartConfig): self.init_weights() + self.model_parallel = False + self.main_device = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + device_map = init_device_map(len(self.encoder.layers), len(self.decoder.layers), device_map) + encoder_device_map = device_map["encoder"] + decoder_device_map = device_map["decoder"] + + self.encoder.parallelize(encoder_device_map) + self.decoder.parallelize(decoder_device_map) + self.main_device = self.encoder.first_device + self.shared.to(self.main_device) + self.model_parallel = True + + # @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # def deparallelize(self): + # self.encoder.deparallelize() + # self.decoder.deparallelize() + # self.lm_head.to("cpu") + # self.model_parallel = False + # torch.cuda.empty_cache() + def get_input_embeddings(self): return self.shared @@ -1040,7 +1239,7 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + # logger.info(f"MP {self.__class__.__name__}. encoder") if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, @@ -1058,6 +1257,7 @@ def forward( attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) + # logger.info(f"MP {self.__class__.__name__}. decoder") # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -1075,6 +1275,11 @@ def forward( if not return_dict: return decoder_outputs + encoder_outputs + if self.model_parallel: + encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device( + self.main_device, encoder_outputs, decoder_outputs + ) + return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, @@ -1107,6 +1312,31 @@ def __init__(self, config: BartConfig): self.init_weights() + self.model_parallel = False + self.main_device = None + self.last_device = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # device_map = init_device_map(len(self.encoder.block), device_map) + self.model.parallelize(device_map) + self.last_device = self.model.decoder.last_device + self.main_device = self.model.main_device + # logger.info(f"MP {self.__class__.__name__} MAIN DEVICE {self.main_device}") + self.final_logits_bias = self.final_logits_bias.to(self.main_device) + self.lm_head.to(self.main_device) + # logger.info(f"MP {self.__class__.__name__}. last device: {self.last_device}") + self.model_parallel = True + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.lm_head, 'self.lm_head')}") + + # @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + # def deparallelize(self): + # self.encoder.deparallelize() + # self.decoder.deparallelize() + # self.lm_head.to("cpu") + # self.model_parallel = False + # torch.cuda.empty_cache() + def get_encoder(self): return self.model.get_encoder() @@ -1182,6 +1412,12 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) + + if self.model_parallel: + self.lm_head.to(self.main_device) + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.lm_head, 'self.lm_head')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(outputs[0], 'outputs[0]')}") + # logger.info(f"MP {self.__class__.__name__} {log_name_device(self.final_logits_bias, 'self.final_logits_bias')}") lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias masked_lm_loss = None diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index effa50b5a92e..d1a964eaaf45 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -267,6 +267,14 @@ def __init__( ) self.model_init = model_init + if self.args.model_parallel: + if model.is_parallelizable: + model.parallelize() + else: + raise ValueError( + f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used" + ) + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py index 3a145df9868b..32910f7edd4c 100644 --- a/src/transformers/utils/model_parallel_utils.py +++ b/src/transformers/utils/model_parallel_utils.py @@ -13,9 +13,271 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from math import ceil +import torch + +def make_default_sub_device_map(n_layers): + """Returns a dictionary of layers distributed evenly across all devices.""" + + # XXX: in the future we can implement a smarter allocation based on the device memory, since: + # 1. cards can be of different memory-size (uncommon, but this developer has this setup) + # 2. also the first device of encoder is used as the main device which will have all the non-layer specific params on it and thus will take more memory, so ideally the default should put less layers on that device + # but, of course, users can customize it to their liking in their code. + # Except it is not possible to customize the map in Trainer-based scripts, like `finetune_trainer.py`, where the user can only switch --model_parallel flag on and no way to set the map. + n_gpus = torch.cuda.device_count() + + # XXX: this function splits the layers evenly across all devices, so that the end result is that each encoder and decoder share devices, as compared to creating maps where either of the two completely takes over one of the devices - need to measure which approach is more efficient - i.e. minimizes inter-device copying. + layers = list(range(n_layers)) + n_blocks = int(ceil(n_layers / n_gpus)) + layers_list = list(layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)) + + return dict(zip(range(n_gpus), layers_list)) + + +def make_default_device_map(encoder_n_layers, decoder_n_layers): + return { + "encoder": make_default_sub_device_map(encoder_n_layers), + "decoder": make_default_sub_device_map(decoder_n_layers), + } + + +def validate_sub_device_map(n_layers, device_map, name): + where = f"device_map['{name}']" + possible_sub_device_map = make_default_sub_device_map(n_layers) + error_msg = f"here is a possible entry for {where}:\n{possible_sub_device_map}" + + # general format + gpu_ids = device_map.keys() + assert all(isinstance(x, int) for x in gpu_ids), ( + f"{where}: All keys much be integers, corresponding to available gpu IDS)\n" + error_msg + ) + layer_ids = [i for v in device_map.values() for i in v] + assert all(isinstance(x, int) for x in layer_ids), ( + f"{where}: Values must contain only integers, corresponding to layer numbers\n" + error_msg + ) + + # reality check + valid_gpu_ids = list(range(torch.cuda.device_count())) + wrong_gpu_ids = [x for x in gpu_ids if x not in valid_gpu_ids] + assert not len(wrong_gpu_ids), ( + f"All keys must correspond to available gpus IDs, but got: {wrong_gpu_ids}\n" + error_msg + ) + + duplicate_layer_ids = [i for i in set(layer_ids) if layer_ids.count(i) > 1] + assert not len(duplicate_layer_ids), ( + f"{where}: duplicate layer numbers detected: {duplicate_layer_ids}\n" + "Each layer number must be specified only once, remove duplicates" + ) + + valid_layer_ids = list(range(0, n_layers)) + missing_layer_ids = [i for i in valid_layer_ids if i not in layer_ids] + assert not len(missing_layer_ids), ( + f"{where}: missing layer numbers detected: {missing_layer_ids}\n" "Add missing layers to the device map." + ) + extra_layer_ids = [i for i in layer_ids if i not in valid_layer_ids] + assert not len(extra_layer_ids), ( + f"{where}: non-existing layer numbers detected: {extra_layer_ids}\n" + f"This {name} has only {n_layers} layers.\n" + "Remove extraneous layers from the device map.\n" + ) + + +def validate_device_map(encoder_n_layers, decoder_n_layers, device_map): + possible_device_map = make_default_device_map(encoder_n_layers, decoder_n_layers) + error_msg = f"invalid device_map format detected, here is a possible device map {possible_device_map}" + + assert "encoder" in device_map and "decoder" in device_map, error_msg + encoder_device_map = device_map["encoder"] + decoder_device_map = device_map["decoder"] + + assert isinstance(encoder_device_map, dict) and isinstance(decoder_device_map, dict), error_msg + + validate_sub_device_map(encoder_n_layers, encoder_device_map, "encoder") + validate_sub_device_map(decoder_n_layers, decoder_device_map, "decoder") + + +def init_device_map(encoder_n_layers, decoder_n_layers, device_map=None): + """ + - creates a device_map if none was passed + - validates that map is correct + + Args: + encoder_n_layers - number of encoder layers to remap + decoder_n_layers - number of decoder layers to remap + device_map - use this user-supplied map + """ + if device_map is None: + device_map = make_default_device_map(encoder_n_layers, decoder_n_layers) + validate_device_map(encoder_n_layers, decoder_n_layers, device_map) + return device_map + + +def log_name_device(var, fallbackname=None): # search from the outmost frame inwards + """ + This helper is useful for debug tracing of devices of variables, e.g.: + logger.info(f"MP {self.__class__.__name__} {log_name_device(attention_mask)}") + if it can't deduce the variable name (or finds wrong name), pass the name explicitly, e.g.: + logger.info(f"MP {self.__class__.__name__} {log_name_device(self.lm_head, 'self.lm_head')}") + """ + + if fallbackname is not None: + name = fallbackname + else: + for f in reversed(inspect.stack()): + name = "unknown" + names = [x for x, val in f.frame.f_locals.items() if val is var] + if len(names) > 0: + name = names[0] + break + if var is None: + return f"{name} val=None" + + device = None + try: + device = var.device + except AttributeError: + if hasattr(var, "parameters"): + device = get_layer_device(var) + return f"{name} {device}" + + +def get_layer_device(self): + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + return device + + +# def to_dev(self, input): +# try: +# device = next(self.parameters(recurse=True)).device +# except StopIteration: +# device = None + +# if device is None: +# raise ValueError(f"Can't find any params for {self.__class__}") +# print(f"manual switch to {device}") +# return input.to(device) + + +def recursive_to(device, item): + """ + Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list, + tuple and dict structures. + """ + + if torch.is_tensor(item): + return item.to(device) + + elif isinstance(item, list): + for i, x in enumerate(item): + item[i] = recursive_to(device, x) + return item + + elif isinstance(item, tuple): + return tuple(recursive_to(device, list(item))) + + elif isinstance(item, dict): + for k, v in item.items(): + item[k] = recursive_to(device, v) + return item + + else: + return item + + +def model_parallel_inputs_to_device(func): + """ + This decorator is a noop unless self.model_parallel == True. + + It will try to find at least one parameter to read layer's .device from and then will automatically copy any inputs + to that device before `forward` is called. Use it as: + + @model_parallel_inputs_to_device def forward(self, input1, input2, ...) + + It will do its magical thing only if all params of this layer are on the same device. If it is not the case use + `model_parallel_inputs_to_specific_device` at the top of `forward` + """ + + def _call__mp(self, *input, **kwargs): + + if not hasattr(self, "model_parallel") or not self.model_parallel: + return func(self, *input, **kwargs) + + # get device of any of the param of this layer + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + + # print(f"layer device: {device}") + if device is not None: + # torch.cuda.set_device(device) + # print(f"auto-move inputs to {device}") + + input = recursive_to(device, input) + kwargs = recursive_to(device, kwargs) + + return func(self, *input, **kwargs) + + return _call__mp + + +def model_parallel_inputs_to_specific_device(device, *input): + """ + Similar to the model_parallel_inputs_to_device decorator, but this one is used for situations either when: 1. an + explicit call is desired (similar to `model.to()`) 2. the layer has params on mixed devices and therefore a wrong + device might get picked + + To use: + + @model_parallel_inputs_to_device def forward(self, input1, input2, ...): # get the desired device somewhere, e.g. a + specific param or a module attribute device = self.fc1.device input1, input2 = + model_parallel_inputs_to_specific_device(device, input1, input2) # this is the same as: input1 = input1.to(device) + input2 = input2.to(device) # but it works on variables that contain tensors but don't have `.to()` otherwise + """ + if device is None: + raise ValueError("device cannot be None") + # print(f"move specific inputs to {device}") + input = recursive_to(device, input) + # remove the need for the caller to perform "a, = foo(a)", + # which otherwise will make `a` a tuple when it might not be one + return input[0] if len(input) == 1 else input + + +# def model_parallel_call(self, *input, **kwargs): + +# # get device of any of the param of this layer +# try: +# device = next(self.parameters(recurse=True)).device +# except StopIteration: +# device = None + +# # print(f"layer device: {device}") +# if device is not None: +# # torch.cuda.set_device(device) +# input = recursive_to(device, input) +# kwargs = recursive_to(device, kwargs) + +# return nn.Module.__call__(self, *input, **kwargs) + + +def print_layer_devices(self): + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + print(f"device dump - looked up device {device}") + for n, p in self.named_parameters(): + print(f"{n}: {p.device}") + + +# XXX: still used by t5 + gpt2 so leave here for now +# will be removed once the other functions above have been integrated def assert_device_map(device_map, num_blocks): blocks = list(range(0, num_blocks)) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 1100c893ae27..cfaffda4a54b 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -106,6 +106,9 @@ def __init__( self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id + def get_large_model_config(self): + return BartConfig.from_pretrained("facebook/bart-base") + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( @@ -391,6 +394,7 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () + all_parallelizable_model_classes = (BartModel, BartForConditionalGeneration) if is_torch_available() else () is_encoder_decoder = True test_pruning = False test_head_masking = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e33efd34b45c..8b6388d2199e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -27,6 +27,10 @@ from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device +if is_torch_available(): + from transformers.utils.model_parallel_utils import model_parallel_inputs_to_specific_device + + if is_torch_available(): import numpy as np import torch @@ -1149,22 +1153,12 @@ def test_model_parallel_equal_results(self): for model_class in self.all_parallelizable_model_classes: inputs_dict = self._prepare_for_class(inputs_dict, model_class) - def cast_to_device(dictionary, device): - output = {} - for k, v in dictionary.items(): - if isinstance(v, torch.Tensor): - output[k] = v.to(device) - else: - output[k] = v - - return output - model = model_class(config) - output = model(**cast_to_device(inputs_dict, "cpu")) - model.parallelize() + output = model(**model_parallel_inputs_to_specific_device("cpu", inputs_dict)) - parallel_output = model(**cast_to_device(inputs_dict, "cuda:0")) + model.parallelize() + parallel_output = model(**model_parallel_inputs_to_specific_device("cuda:0", inputs_dict)) for value, parallel_value in zip(output, parallel_output): if isinstance(value, torch.Tensor): @@ -1183,23 +1177,15 @@ def test_model_parallel_beam_search(self): ) config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if "decoder_input_ids" in inputs_dict: + inputs_dict.pop("decoder_input_ids") # Bart* breaks if it's passed for model_class in all_generative_and_parallelizable_model_classes: inputs_dict = self._prepare_for_class(inputs_dict, model_class) model = model_class(config) - def cast_to_device(dictionary, device): - output = {} - for k, v in dictionary.items(): - if isinstance(v, torch.Tensor): - output[k] = v.to(device) - else: - output[k] = v - - return output - model.parallelize() - model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2) + model.generate(**model_parallel_inputs_to_specific_device("cuda:0", inputs_dict), num_beams=2) global_rng = random.Random()