diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 3b096c36db8e..25dbbf32b072 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -40,7 +40,11 @@
 )
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
-from ...utils.model_parallel_utils import assert_device_map, get_device_map
+from ...utils.model_parallel_utils import (
+    init_device_map,
+    model_parallel_inputs_to_device,
+    model_parallel_inputs_to_specific_device,
+)
 from .configuration_t5 import T5Config
@@ -597,6 +601,17 @@ def __init__(self, config, has_relative_attention_bias=False):

         self.layer.append(T5LayerFF(config))

+        self.model_parallel = False
+
+    def block_parallelize(self, device):
+        self.to(device)
+        self.model_parallel = True
+
+    def block_deparallelize(self):
+        self.to("cpu")
+        self.model_parallel = False
+
+    @model_parallel_inputs_to_device
     def forward(
         self,
         hidden_states,
@@ -776,41 +791,34 @@ def __init__(self, config, embed_tokens=None):
         self.dropout = nn.Dropout(config.dropout_rate)

         self.init_weights()
-        # Model parallel
+
         self.model_parallel = False
-        self.device_map = None
+        self.first_device = "cpu"
+        self.last_device = "cpu"

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        # Check validity of device_map
-        self.device_map = (
-            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
-        )
-        assert_device_map(self.device_map, len(self.block))
-        self.model_parallel = True
-        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
+        device_map = init_device_map(len(self.block), device_map)
+        self.first_device = f"cuda:{list(device_map.keys())[0]}"
+        self.last_device = f"cuda:{list(device_map.keys())[-1]}"
         # Load onto devices
-        for k, v in self.device_map.items():
+        for k, v in device_map.items():
             for layer in v:
-                cuda_device = "cuda:" + str(k)
-                self.block[layer] = self.block[layer].to(cuda_device)
-
-        # Set embed_tokens to first layer
-        self.embed_tokens = self.embed_tokens.to(self.first_device)
-        # Set final layer norm to last device
-        self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+                self.block[layer].block_parallelize(f"cuda:{k}")
+        self.embed_tokens.to(self.first_device)
+        self.final_layer_norm.to(self.last_device)
+        self.model_parallel = True

-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
-        self.model_parallel = False
-        self.device_map = None
+        self.to("cpu")
         self.first_device = "cpu"
         self.last_device = "cpu"
         for i in range(len(self.block)):
-            self.block[i] = self.block[i].to("cpu")
-        self.embed_tokens = self.embed_tokens.to("cpu")
-        self.final_layer_norm = self.final_layer_norm.to("cpu")
+            self.block[i].block_deparallelize()
+        self.embed_tokens.to("cpu")
+        self.final_layer_norm.to("cpu")
+        self.model_parallel = False
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
@@ -833,10 +841,6 @@ def forward(
         output_hidden_states=None,
         return_dict=None,
     ):
-        # Model parallel
-        if self.model_parallel:
-            torch.cuda.set_device(self.first_device)
-            self.embed_tokens = self.embed_tokens.to(self.first_device)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -904,20 +908,6 @@ def forward(
         hidden_states = self.dropout(inputs_embeds)

         for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if position_bias is not None:
-                    position_bias = position_bias.to(hidden_states.device)
-                if encoder_hidden_states is not None:
-                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
-                if encoder_extended_attention_mask is not None:
-                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
-                if encoder_decoder_position_bias is not None:
-                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -952,12 +942,6 @@ def forward(
             if self.is_decoder:
                 all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -1161,30 +1145,22 @@ def __init__(self, config: T5Config):

         self.init_weights()

-        # Model parallel
         self.model_parallel = False
-        self.device_map = None
+        self.main_device = None

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.encoder.block))
-        self.encoder.parallelize(self.device_map)
-        self.decoder.parallelize(self.device_map)
+        device_map = init_device_map(len(self.encoder.block), device_map)
+        self.encoder.parallelize(device_map)
+        self.decoder.parallelize(device_map)
+        self.main_device = self.encoder.first_device
         self.model_parallel = True

     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
         self.encoder.deparallelize()
         self.decoder.deparallelize()
-        self.encoder = self.encoder.to("cpu")
-        self.decoder = self.decoder.to("cpu")
         self.model_parallel = False
-        self.device_map = None
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
@@ -1265,18 +1241,6 @@ def forward(
             )

         hidden_states = encoder_outputs[0]
-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-            hidden_states = hidden_states.to(self.decoder.first_device)
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
-            if attention_mask is not None:
-                attention_mask = attention_mask.to(self.decoder.first_device)
-            if decoder_attention_mask is not None:
-                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

         # Decode
         decoder_outputs = self.decoder(
@@ -1296,6 +1260,11 @@ def forward(
         if not return_dict:
             return decoder_outputs + encoder_outputs

+        if self.model_parallel:
+            encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device(
+                self.main_device, encoder_outputs, decoder_outputs
+            )
+
         return Seq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             past_key_values=decoder_outputs.past_key_values,
@@ -1341,32 +1310,26 @@ def __init__(self, config):

         self.init_weights()

-        # Model parallel
         self.model_parallel = False
-        self.device_map = None
+        self.first_device = None
+        self.last_device = None

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.encoder.block))
-        self.encoder.parallelize(self.device_map)
-        self.decoder.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.decoder.first_device)
+        device_map = init_device_map(len(self.encoder.block), device_map)
+        self.encoder.parallelize(device_map)
+        self.decoder.parallelize(device_map)
+        self.first_device = self.encoder.first_device
+        self.last_device = self.decoder.last_device
+        self.lm_head.to(self.first_device)
         self.model_parallel = True

     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
         self.encoder.deparallelize()
         self.decoder.deparallelize()
-        self.encoder = self.encoder.to("cpu")
-        self.decoder = self.decoder.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
+        self.lm_head.to("cpu")
         self.model_parallel = False
-        self.device_map = None
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
@@ -1456,9 +1419,6 @@ def forward(

         hidden_states = encoder_outputs[0]

-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-
         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
             # get decoder inputs from shifting lm labels to the right
             decoder_input_ids = self._shift_right(labels)
@@ -1472,17 +1432,6 @@ def forward(
             if decoder_inputs_embeds is not None:
                 decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-            hidden_states = hidden_states.to(self.decoder.first_device)
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
-            if attention_mask is not None:
-                attention_mask = attention_mask.to(self.decoder.first_device)
-            if decoder_attention_mask is not None:
-                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
-
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
@@ -1498,13 +1447,13 @@ def forward(
             return_dict=return_dict,
         )

-        sequence_output = decoder_outputs[0]
-
-        # Set device for model parallelism
         if self.model_parallel:
-            torch.cuda.set_device(self.encoder.first_device)
-            self.lm_head = self.lm_head.to(self.encoder.first_device)
-            sequence_output = sequence_output.to(self.lm_head.weight.device)
+            encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device(
+                self.first_device, encoder_outputs, decoder_outputs
+            )
+            # self.lm_head.to(self.first_device)
+
+        sequence_output = decoder_outputs[0]

         if self.config.tie_word_embeddings:
             # Rescale output before projecting on vocab
@@ -1596,23 +1545,18 @@ def __init__(self, config: T5Config):

         self.init_weights()

+        self.model_parallel = False
+
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.encoder.block))
-        self.encoder.parallelize(self.device_map)
+        device_map = init_device_map(len(self.encoder.block), device_map)
+        self.encoder.parallelize(device_map)
         self.model_parallel = True

     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
         self.encoder.deparallelize()
-        self.encoder = self.encoder.to("cpu")
         self.model_parallel = False
-        self.device_map = None
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index effa50b5a92e..d1a964eaaf45 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -267,6 +267,14 @@ def __init__(
         )
         self.model_init = model_init

+        if self.args.model_parallel:
+            if model.is_parallelizable:
+                model.parallelize()
+            else:
+                raise ValueError(
+                    f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore the --model_parallel command-line argument cannot be used"
+                )
+
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py
index 3a145df9868b..fedd0c3d3eb4 100644
--- a/src/transformers/utils/model_parallel_utils.py
+++ b/src/transformers/utils/model_parallel_utils.py
@@ -15,8 +15,10 @@
 from math import ceil

+import torch

-def assert_device_map(device_map, num_blocks):
+
+def validate_device_map(device_map, num_blocks):
     blocks = list(range(0, num_blocks))

     device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist]
@@ -45,6 +47,127 @@
     )


+def make_default_device_map(n_layers):
+    """Returns a dictionary of layers distributed evenly across all devices."""
+    n_gpus = torch.cuda.device_count()
+    layers = list(range(n_layers))
+    n_blocks = int(ceil(n_layers / n_gpus))
+    layers_list = list(layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks))
+
+    return dict(zip(range(n_gpus), layers_list))
+
+
+def init_device_map(n_layers, device_map=None):
+    """
+    - creates a device_map if none was passed
+    - validates that the map is correct
+
+    Args:
+        n_layers: how many total layers to remap
+    """
+    if device_map is None:
+        device_map = make_default_device_map(n_layers)
+    validate_device_map(device_map, n_layers)
+    return device_map
+
+
+def get_layer_device(self):
+    try:
+        device = next(self.parameters(recurse=True)).device
+    except StopIteration:
+        device = None
+    return device
+
+
+def recursive_to(device, item):
+    """
+    Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list,
+    tuple and dict structures.
+    """
+    if torch.is_tensor(item):
+        return item.to(device)
+    elif isinstance(item, list):
+        for i, x in enumerate(item):
+            item[i] = recursive_to(device, x)
+        return item
+    elif isinstance(item, tuple):
+        return tuple(recursive_to(device, list(item)))
+    elif isinstance(item, dict):
+        for k, v in item.items():
+            item[k] = recursive_to(device, v)
+        return item
+    else:
+        return item
+
+
+def model_parallel_inputs_to_device(func):
+    """
+    This decorator is a no-op unless self.model_parallel == True.
+
+    It will try to find at least one parameter to read the layer's .device from and will then automatically copy any
+    inputs to that device before `forward` is called. Use it as:
+
+        @model_parallel_inputs_to_device
+        def forward(self, input1, input2, ...)
+
+    It will do its magical thing only if all params of this layer are on the same device. If that is not the case, use
+    `model_parallel_inputs_to_specific_device` at the top of `forward` instead.
+    """

+    def _call__mp(self, *input, **kwargs):
+
+        if not hasattr(self, "model_parallel") or not self.model_parallel:
+            return func(self, *input, **kwargs)
+
+        # get the device of any of the params of this layer
+        try:
+            device = next(self.parameters(recurse=True)).device
+        except StopIteration:
+            device = None
+
+        # print(f"layer device: {device}")
+        if device is not None:
+            # torch.cuda.set_device(device)
+            # print(f"auto-move inputs to {device}")
+
+            input = recursive_to(device, input)
+            kwargs = recursive_to(device, kwargs)
+
+        return func(self, *input, **kwargs)
+
+    return _call__mp
+
+
+def model_parallel_inputs_to_specific_device(device, *input):
+    """
+    Similar to the `model_parallel_inputs_to_device` decorator, but this one is used for situations when either:
+
+    1. an explicit call is desired (similar to `model.to()`), or
+    2. the layer has params on mixed devices, so the decorator might pick a wrong device.
+
+    To use:
+
+        def forward(self, input1, input2, ...):
+            # get the desired device somewhere, e.g. from a specific param or a module attribute
+            device = self.fc1.weight.device
+            input1, input2 = model_parallel_inputs_to_specific_device(device, input1, input2)
+
+    which is the same as:
+
+        input1 = input1.to(device)
+        input2 = input2.to(device)
+
+    but it works on variables that contain tensors and don't have `.to()` otherwise.
+    """
+    if device is None:
+        raise ValueError("device cannot be None")
+    # print(f"move specific inputs to {device}")
+    input = recursive_to(device, input)
+    # remove the need for the caller to perform "a, = foo(a)",
+    # which otherwise would make `a` a tuple when it might not be one
+    return input[0] if len(input) == 1 else input
+
+
+# XXX: still used by gpt2 so leave here for now
+assert_device_map = validate_device_map
+
+
 def get_device_map(n_layers, devices):
     """Returns a dictionary of layers distributed evenly across all devices."""
     layers = list(range(n_layers))
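Usage sketch, not part of the patch above: a minimal, hypothetical example of how the reworked API is meant to be driven, assuming at least two visible CUDA GPUs. The checkpoint name and the explicit device_map values are illustrative assumptions; when no map is passed, init_device_map() splits the blocks evenly across all visible GPUs, and validate_device_map() rejects maps with duplicate or missing blocks.

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-large")
    model = T5ForConditionalGeneration.from_pretrained("t5-large")

    # Either let init_device_map() derive an even {gpu_ordinal: [block indices]} split ...
    model.parallelize()
    # ... or pass an explicit device_map, e.g. t5-large's 24 blocks over 2 GPUs:
    # model.parallelize({0: list(range(12)), 1: list(range(12, 24))})

    # Inputs only need to reach the first device; the @model_parallel_inputs_to_device
    # decorator on T5Block.forward re-homes activations and masks for each block.
    batch = tokenizer("translate English to German: Hello world", return_tensors="pt")
    batch = batch.to(model.first_device)
    output_ids = model.generate(**batch)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

    # Moves everything back to CPU and calls torch.cuda.empty_cache()
    model.deparallelize()

With the Trainer, the same thing is triggered by passing --model_parallel, which calls model.parallelize() in Trainer.__init__ when the model reports is_parallelizable.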