From 3620a52b3738b50d625d8e06e52377303caea717 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 26 Dec 2020 20:35:38 -0800 Subject: [PATCH 1/6] do not sort devices by number --- src/transformers/models/t5/modeling_t5.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0ce2be3c62ac..81234fd60ef6 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -785,17 +785,15 @@ def parallelize(self, device_map=None): ) assert_device_map(self.device_map, len(self.block)) self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.first_device = "cpu" if "cpu" in self.device_map.keys() else f"cuda:{ list(self.device_map.keys())[0] }" + self.last_device = f"cuda:{ list(self.device_map.keys())[-1] }" # Load onto devices for k, v in self.device_map.items(): for layer in v: cuda_device = "cuda:" + str(k) self.block[layer] = self.block[layer].to(cuda_device) - # Set embed_tokens to first layer self.embed_tokens = self.embed_tokens.to(self.first_device) - # Set final layer norm to last device self.final_layer_norm = self.final_layer_norm.to(self.last_device) @add_start_docstrings(PARALLELIZE_DOCSTRING) From 07bf1ac93347ba6e513cc32cb9b9e3e77034ce15 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 26 Dec 2020 21:25:06 -0800 Subject: [PATCH 2/6] remove duplicate; recode for any stride in device numbers --- src/transformers/models/t5/modeling_t5.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 81234fd60ef6..0ee719db0444 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -831,7 +831,6 @@ def forward( # Model parallel if self.model_parallel: torch.cuda.set_device(self.first_device) - self.embed_tokens = self.embed_tokens.to(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -949,9 +948,10 @@ def forward( # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) + devices = list(self.device_map.keys()) + for e, d in enumerate(devices): + if i == self.device_map[d][-1] and f"cuda:{d}" != self.last_device: + hidden_states = hidden_states.to(f"cuda:{devices[e+1]}") hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) From 5851730dd6186b121c41c55ee45d421b876f6947 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 26 Dec 2020 22:51:14 -0800 Subject: [PATCH 3/6] one out, one in --- src/transformers/models/t5/modeling_t5.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0ee719db0444..b9260398cee4 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -901,9 +901,7 @@ def forward( # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + # Ensure that the layer_module args are on the same device as hidden_states if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -912,6 +910,9 @@ def forward( encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) if encoder_decoder_position_bias is not None: encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if all_hidden_states is not None: + all_hidden_states = all_hidden_states.to(hidden_states.device) + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From da46cac228097ad80558710f99e40157cd16b1c6 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 27 Dec 2020 18:53:53 -0800 Subject: [PATCH 4/6] auto-relocate-inputs; refactor --- src/transformers/models/t5/modeling_t5.py | 153 +++++------------- .../utils/model_parallel_utils.py | 73 ++++++++- tests/test_modeling_t5.py | 1 + 3 files changed, 114 insertions(+), 113 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b9260398cee4..c83734cedc7a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -40,7 +40,7 @@ ) from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging -from ...utils.model_parallel_utils import assert_device_map, get_device_map +from ...utils.model_parallel_utils import init_device_map, model_parallel_inputs_to_device from .configuration_t5 import T5Config @@ -595,6 +595,17 @@ def __init__(self, config, has_relative_attention_bias=False): self.layer.append(T5LayerFF(config)) + self.model_parallel = False + + def block_parallelize(self, device): + self.to(device) + self.model_parallel = True + + def block_deparallelize(self): + self.to("cpu") + self.model_parallel = False + + @model_parallel_inputs_to_device def forward( self, hidden_states, @@ -773,39 +784,34 @@ def __init__(self, config, embed_tokens=None): self.dropout = nn.Dropout(config.dropout_rate) self.init_weights() - # Model parallel + self.model_parallel = False - self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map - ) - assert_device_map(self.device_map, len(self.block)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else f"cuda:{ list(self.device_map.keys())[0] }" - self.last_device = f"cuda:{ list(self.device_map.keys())[-1] }" + device_map = init_device_map(len(self.block), device_map) + self.first_device = f"cuda:{ list(device_map.keys())[0] }" + self.last_device = f"cuda:{ list(device_map.keys())[-1] }" # Load onto devices - for k, v in self.device_map.items(): + for k, v in device_map.items(): for layer in v: - cuda_device = "cuda:" + str(k) - self.block[layer] = self.block[layer].to(cuda_device) - - self.embed_tokens = self.embed_tokens.to(self.first_device) - self.final_layer_norm = self.final_layer_norm.to(self.last_device) + self.block[layer].block_parallelize(f"cuda:{k}") + self.embed_tokens.to(self.first_device) + self.final_layer_norm.to(self.last_device) + self.model_parallel = True - @add_start_docstrings(PARALLELIZE_DOCSTRING) + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) def deparallelize(self): - self.model_parallel = False - self.device_map = None + self.to("cpu") self.first_device = "cpu" self.last_device = "cpu" for i in range(len(self.block)): - self.block[i] = self.block[i].to("cpu") - self.embed_tokens = self.embed_tokens.to("cpu") - self.final_layer_norm = self.final_layer_norm.to("cpu") + self.block[i].block_deparallelize() + self.embed_tokens.to("cpu") + self.final_layer_norm.to("cpu") + self.model_parallel = False torch.cuda.empty_cache() def get_input_embeddings(self): @@ -828,9 +834,6 @@ def forward( output_hidden_states=None, return_dict=None, ): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -898,21 +901,6 @@ def forward( hidden_states = self.dropout(inputs_embeds) for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure that the layer_module args are on the same device as hidden_states - if position_bias is not None: - position_bias = position_bias.to(hidden_states.device) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) - if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) - if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) - if all_hidden_states is not None: - all_hidden_states = all_hidden_states.to(hidden_states.device) - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -947,13 +935,6 @@ def forward( if self.is_decoder: all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - devices = list(self.device_map.keys()) - for e, d in enumerate(devices): - if i == self.device_map[d][-1] and f"cuda:{d}" != self.last_device: - hidden_states = hidden_states.to(f"cuda:{devices[e+1]}") - hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1148,30 +1129,20 @@ def __init__(self, config: T5Config): self.init_weights() - # Model parallel self.model_parallel = False - self.device_map = None @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) + device_map = init_device_map(len(self.encoder.block), device_map) + self.encoder.parallelize(device_map) + self.decoder.parallelize(device_map) self.model_parallel = True @add_start_docstrings(DEPARALLELIZE_DOCSTRING) def deparallelize(self): self.encoder.deparallelize() self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") self.model_parallel = False - self.device_map = None torch.cuda.empty_cache() def get_input_embeddings(self): @@ -1252,18 +1223,6 @@ def forward( ) hidden_states = encoder_outputs[0] - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) # Decode decoder_outputs = self.decoder( @@ -1327,32 +1286,22 @@ def __init__(self, config): self.init_weights() - # Model parallel self.model_parallel = False - self.device_map = None @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.decoder.first_device) + device_map = init_device_map(len(self.encoder.block), device_map) + self.encoder.parallelize(device_map) + self.decoder.parallelize(device_map) + self.lm_head.to(self.decoder.first_device) self.model_parallel = True @add_start_docstrings(DEPARALLELIZE_DOCSTRING) def deparallelize(self): self.encoder.deparallelize() self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") - self.lm_head = self.lm_head.to("cpu") + self.lm_head.to("cpu") self.model_parallel = False - self.device_map = None torch.cuda.empty_cache() def get_input_embeddings(self): @@ -1442,9 +1391,6 @@ def forward( hidden_states = encoder_outputs[0] - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) @@ -1458,17 +1404,6 @@ def forward( if decoder_inputs_embeds is not None: decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) - # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -1488,7 +1423,6 @@ def forward( # Set device for model parallelism if self.model_parallel: - torch.cuda.set_device(self.encoder.first_device) self.lm_head = self.lm_head.to(self.encoder.first_device) sequence_output = sequence_output.to(self.lm_head.weight.device) @@ -1582,23 +1516,18 @@ def __init__(self, config: T5Config): self.init_weights() + self.model_parallel = False + @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) + device_map = init_device_map(len(self.encoder.block), device_map) + self.encoder.parallelize(device_map) self.model_parallel = True @add_start_docstrings(DEPARALLELIZE_DOCSTRING) def deparallelize(self): self.encoder.deparallelize() - self.encoder = self.encoder.to("cpu") self.model_parallel = False - self.device_map = None torch.cuda.empty_cache() def get_input_embeddings(self): diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py index 3a145df9868b..7d4b5f8933d0 100644 --- a/src/transformers/utils/model_parallel_utils.py +++ b/src/transformers/utils/model_parallel_utils.py @@ -15,8 +15,10 @@ from math import ceil +import torch -def assert_device_map(device_map, num_blocks): + +def validate_device_map(device_map, num_blocks): blocks = list(range(0, num_blocks)) device_map_blocks = [item for sublist in list(device_map.values()) for item in sublist] @@ -45,6 +47,75 @@ def assert_device_map(device_map, num_blocks): ) +def make_default_device_map(n_layers): + """Returns a dictionary of layers distributed evenly across all devices.""" + n_gpus = torch.cuda.device_count() + layers = list(range(n_layers)) + n_blocks = int(ceil(n_layers / n_gpus)) + layers_list = list(layers[i : i + n_blocks] for i in range(0, n_layers, n_blocks)) + + return dict(zip(range(n_gpus), layers_list)) + + +def init_device_map(n_layers, device_map=None): + """ + - creates a device_map if none was passed + - validates that map is correct + + Args: + n_layers - how many total layers to remap + """ + if device_map is None: + device_map = make_default_device_map(n_layers) + validate_device_map(device_map, n_layers) + return device_map + + +def model_parallel_inputs_to_device(func): + """ + This decorator will try to find a at least one parameter or a buffer to read layer's .device from and then will + automatically copy any inputs to that device before `forward` is called. + + this will work do its magical thing only if all params of this layer are on the same device + """ + + def _call__mp(self, *input, **kwargs): + + if not hasattr(self, "model_parallel") or not self.model_parallel: + return func(self, *input, **kwargs) + + # get device of any of the param of this layer + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + + # print(f"layer device: {device}") + if device is not None: + + # pprint(input) + input = list(input) + for i, v in enumerate(input): + if v is not None: + input[i] = v.to(device) + input = tuple(input) + # pprint(input) + + # pprint(kwargs) + for k in kwargs.keys(): + if kwargs[k] is not None and torch.is_tensor(kwargs[k]): + kwargs[k] = kwargs[k].to(device) + # pprint(kwargs) + + return func(self, *input, **kwargs) + + return _call__mp + + +# XXX: still used by gpt2 so leave here for now +assert_device_map = validate_device_map + + def get_device_map(n_layers, devices): """Returns a dictionary of layers distributed evenly across all devices.""" layers = list(range(n_layers)) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index bcf4e585fef8..9eca8e47cb8f 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -488,6 +488,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () + # T5EncoderModel too? all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False test_torchscript = True From c1b98e99d0b2dc724e1cd912b9d42a8d49e32982 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 27 Dec 2020 19:07:13 -0800 Subject: [PATCH 5/6] cleanup --- tests/test_modeling_t5.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 9eca8e47cb8f..bcf4e585fef8 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -488,7 +488,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () - # T5EncoderModel too? all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False test_torchscript = True From 58d047a596a97fbb815acb3e657102bf1960b06a Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 7 Jan 2021 12:17:56 -0800 Subject: [PATCH 6/6] fixes for generate tools + activate parallelize in trainer --- src/transformers/models/t5/modeling_t5.py | 30 +++++-- src/transformers/trainer.py | 8 ++ .../utils/model_parallel_utils.py | 86 +++++++++++++++---- 3 files changed, 100 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 57f48ed8f98c..25dbbf32b072 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -40,7 +40,11 @@ ) from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging -from ...utils.model_parallel_utils import init_device_map, model_parallel_inputs_to_device +from ...utils.model_parallel_utils import ( + init_device_map, + model_parallel_inputs_to_device, + model_parallel_inputs_to_specific_device, +) from .configuration_t5 import T5Config @@ -1142,12 +1146,14 @@ def __init__(self, config: T5Config): self.init_weights() self.model_parallel = False + self.main_device = None @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): device_map = init_device_map(len(self.encoder.block), device_map) self.encoder.parallelize(device_map) self.decoder.parallelize(device_map) + self.main_device = self.encoder.first_device self.model_parallel = True @add_start_docstrings(DEPARALLELIZE_DOCSTRING) @@ -1254,6 +1260,11 @@ def forward( if not return_dict: return decoder_outputs + encoder_outputs + if self.model_parallel: + encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device( + self.main_device, encoder_outputs, decoder_outputs + ) + return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, @@ -1300,13 +1311,17 @@ def __init__(self, config): self.init_weights() self.model_parallel = False + self.first_device = None + self.last_device = None @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): device_map = init_device_map(len(self.encoder.block), device_map) self.encoder.parallelize(device_map) self.decoder.parallelize(device_map) - self.lm_head.to(self.decoder.first_device) + self.first_device = self.encoder.first_device + self.last_device = self.decoder.last_device + self.lm_head.to(self.first_device) self.model_parallel = True @add_start_docstrings(DEPARALLELIZE_DOCSTRING) @@ -1432,12 +1447,13 @@ def forward( return_dict=return_dict, ) - sequence_output = decoder_outputs[0] - - # Set device for model parallelism if self.model_parallel: - self.lm_head = self.lm_head.to(self.encoder.first_device) - sequence_output = sequence_output.to(self.lm_head.weight.device) + encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device( + self.first_device, encoder_outputs, decoder_outputs + ) + # self.lm_head.to(self.first_device) + + sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: # Rescale output before projecting on vocab diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index effa50b5a92e..d1a964eaaf45 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -267,6 +267,14 @@ def __init__( ) self.model_init = model_init + if self.args.model_parallel: + if model.is_parallelizable: + model.parallelize() + else: + raise ValueError( + f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used" + ) + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py index 7d4b5f8933d0..fedd0c3d3eb4 100644 --- a/src/transformers/utils/model_parallel_utils.py +++ b/src/transformers/utils/model_parallel_utils.py @@ -71,12 +71,51 @@ def init_device_map(n_layers, device_map=None): return device_map +def get_layer_device(self): + try: + device = next(self.parameters(recurse=True)).device + except StopIteration: + device = None + return device + + +def recursive_to(device, item): + """ + Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list, + tuple and dict structures. + """ + + if torch.is_tensor(item): + return item.to(device) + + elif isinstance(item, list): + for i, x in enumerate(item): + item[i] = recursive_to(device, x) + return item + + elif isinstance(item, tuple): + return tuple(recursive_to(device, list(item))) + + elif isinstance(item, dict): + for k, v in item.items(): + item[k] = recursive_to(device, v) + return item + + else: + return item + + def model_parallel_inputs_to_device(func): """ - This decorator will try to find a at least one parameter or a buffer to read layer's .device from and then will - automatically copy any inputs to that device before `forward` is called. + This decorator is a noop unless self.model_parallel == True. - this will work do its magical thing only if all params of this layer are on the same device + It will try to find at least one parameter to read layer's .device from and then will automatically copy any inputs + to that device before `forward` is called. Use it as: + + @model_parallel_inputs_to_device def forward(self, input1, input2, ...) + + It will do its magical thing only if all params of this layer are on the same device. If it is not the case use + `model_parallel_inputs_to_specific_device` at the top of `forward` """ def _call__mp(self, *input, **kwargs): @@ -92,26 +131,39 @@ def _call__mp(self, *input, **kwargs): # print(f"layer device: {device}") if device is not None: + # torch.cuda.set_device(device) + # print(f"auto-move inputs to {device}") - # pprint(input) - input = list(input) - for i, v in enumerate(input): - if v is not None: - input[i] = v.to(device) - input = tuple(input) - # pprint(input) - - # pprint(kwargs) - for k in kwargs.keys(): - if kwargs[k] is not None and torch.is_tensor(kwargs[k]): - kwargs[k] = kwargs[k].to(device) - # pprint(kwargs) + input = recursive_to(device, input) + kwargs = recursive_to(device, kwargs) - return func(self, *input, **kwargs) + return func(self, *input, **kwargs) return _call__mp +def model_parallel_inputs_to_specific_device(device, *input): + """ + Similar to the model_parallel_inputs_to_device decorator, but this one is used for situations either when: 1. an + explicit call is desired (similar to `model.to()`) 2. the layer has params on mixed devices and therefore a wrong + device might get picked + + To use: + + @model_parallel_inputs_to_device def forward(self, input1, input2, ...): # get the desired device somewhere, e.g. a + specific param or a module attribute device = self.fc1.device input1, input2 = + model_parallel_inputs_to_specific_device(device, input1, input2) # this is the same as: input1 = input1.to(device) + input2 = input2.to(device) # but it works on variables that contain tensors but don't have `.to()` otherwise + """ + if device is None: + raise ValueError("device cannot be None") + # print(f"move specific inputs to {device}") + input = recursive_to(device, input) + # remove the need for the caller to perform "a, = foo(a)", + # which otherwise will make `a` a tuple when it might not be one + return input[0] if len(input) == 1 else input + + # XXX: still used by gpt2 so leave here for now assert_device_map = validate_device_map