[T5 model parallel] implement input auto-relocation + lots of refactoring/cleanup #9323
Changes from all commits: 3620a52, 07bf1ac, 5851730, da46cac, c1b98e9, ab05fb6, 58d047a
```diff
@@ -40,7 +40,11 @@
 )
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
-from ...utils.model_parallel_utils import assert_device_map, get_device_map
+from ...utils.model_parallel_utils import (
+    init_device_map,
+    model_parallel_inputs_to_device,
+    model_parallel_inputs_to_specific_device,
+)
 from .configuration_t5 import T5Config
```
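The `model_parallel_inputs_to_device` decorator imported here is defined in `model_parallel_utils` and is not part of this diff. As a minimal sketch of the idea, assuming it moves every tensor argument onto the device of the decorated module's own parameters (the `TinyBlock` class and the exact behavior are illustrative assumptions, not the PR's actual implementation):

```python
import torch
from torch import nn
from functools import wraps


def model_parallel_inputs_to_device(fn):
    # Sketch (assumption, not the PR's code): before running forward(),
    # relocate any tensor args/kwargs onto the device where this module's
    # parameters live, so callers never move inputs manually.
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        if getattr(self, "model_parallel", False):
            device = next(self.parameters()).device
            args = tuple(a.to(device) if torch.is_tensor(a) else a for a in args)
            kwargs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in kwargs.items()}
        return fn(self, *args, **kwargs)
    return wrapper


class TinyBlock(nn.Module):
    # Hypothetical stand-in for a T5Block, small enough to run anywhere.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.model_parallel = True  # pretend this block was parallelized

    @model_parallel_inputs_to_device
    def forward(self, hidden_states):
        return self.linear(hidden_states)


block = TinyBlock()  # parameters live on CPU here, so inputs stay on CPU
out = block(torch.ones(2, 4))
print(out.shape)  # torch.Size([2, 4])
```

With parameters on CPU the decorator is a no-op move; on a multi-GPU setup the same call would land the inputs on whichever device the block was moved to.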
```diff
@@ -597,6 +601,17 @@ def __init__(self, config, has_relative_attention_bias=False):
         self.layer.append(T5LayerFF(config))

+        self.model_parallel = False
+
+    def block_parallelize(self, device):
+        self.to(device)
+        self.model_parallel = True
+
+    def block_deparallelize(self):
+        self.to("cpu")
+        self.model_parallel = False
+
+    @model_parallel_inputs_to_device
     def forward(
         self,
         hidden_states,
```
```diff
@@ -776,41 +791,34 @@ def __init__(self, config, embed_tokens=None):
         self.dropout = nn.Dropout(config.dropout_rate)

         self.init_weights()

         # Model parallel
```
Contributor: Wouldn't something like:

```python
if self.model_parallel:
    hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value = _call__mp(
        hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value
    )
```

directly under the function signature work? But I agree that it's a lot of boilerplate code, so maybe the best way is indeed to use a decorator function here.

Contributor (Author): Oh, I see: you are suggesting to revert to the original approach, where each input was manually moved to its destination device, albeit at the point of use. There is a definite benefit to your suggestion in that it is explicit rather than a magical, behind-the-scenes switch. My initial reaction was "meh", but sitting with it more, it fits well with all the other … The more I think about it, the more I believe a future pytorch might find a way to avoid manually moving inputs to devices, and not just for MP: it should figure it all out based on the params, which already know which devices they are on. And then it will all be magical.

In case it wasn't clear, my follow-up suggested not using a decorator, since you decided against it, but an even more efficient way which would only impact MP code. Let's see what @LysandreJik thinks.

Member: Thanks for waiting for my input. I would prefer to go down the explicit road rather than the magical one. Patrick's proposition seems like the most understandable to me, while keeping all the features you proposed, @stas00. The …

Contributor (Author): Explicit calls it is then; works for me. wrt …
```diff
         self.model_parallel = False
         self.device_map = None
+        self.first_device = "cpu"
+        self.last_device = "cpu"

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        # Check validity of device_map
-        self.device_map = (
-            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
-        )
-        assert_device_map(self.device_map, len(self.block))
-        self.model_parallel = True
-        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
+        device_map = init_device_map(len(self.block), device_map)
+        self.first_device = f"cuda:{ list(device_map.keys())[0] }"
+        self.last_device = f"cuda:{ list(device_map.keys())[-1] }"
```
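The new `init_device_map` helper lives in `model_parallel_utils` and is not part of this diff. As a rough sketch of the merged "default map plus validation" behavior it replaces (`get_device_map` + `assert_device_map`): the `n_devices` parameter below is an assumption standing in for `torch.cuda.device_count()`, and the even-split policy is illustrative.

```python
def init_device_map(n_layers, device_map=None, n_devices=2):
    # Sketch (assumption): build a default layer -> device assignment when
    # none is given, then validate it, mirroring what init_device_map in
    # this PR folds into one function.
    if device_map is None:
        # Spread layers as evenly as possible over the available devices.
        per_device = -(-n_layers // n_devices)  # ceiling division
        layers = list(range(n_layers))
        device_map = {
            d: layers[d * per_device:(d + 1) * per_device] for d in range(n_devices)
        }
        device_map = {d: ls for d, ls in device_map.items() if ls}
    # Validation: every layer must be assigned exactly once.
    assigned = sorted(l for ls in device_map.values() for l in ls)
    assert assigned == list(range(n_layers)), "device_map must cover each block exactly once"
    return device_map


print(init_device_map(6))  # {0: [0, 1, 2], 1: [3, 4, 5]}
```

Passing an explicit map skips the default and only runs the validation, which is the code path the T5 `parallelize(device_map=...)` calls below exercise.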
Contributor: Maybe we should instead put some additional work into … and keep the old language on …

Contributor (Author): I'm not even sure we need first/last devices at all. But yes, I think eventually each model will have to implement its own default device map; see the work I started in #9384. Then there will be no need to infer it, since the model will just have a method that does that. Of course, we don't have to re-implement it for each model with a similar arch, so perhaps we will end up with a set of default maps and each model will pick and customize the one it needs, e.g. t5 and bart are almost the same except for the name of the module list holding the layers/blocks. wrt some layers on … that's actually a cool thing, since then one could play with MP with just one gpu.

Member: No strong opposition to updating from …

Contributor (Author): It's not a rename, it's a change in functionality: the default and the check are merged into one function to avoid pointless code repetition, which doesn't contribute to easier understanding. And yes, point taken on supporting cpu. I propose that the first stage is to polish the API for gpus only, and then definitely include cpu, as it'd be helpful for developing/debugging MP with just one gpu. I will start making todo items.
```diff
         # Load onto devices
-        for k, v in self.device_map.items():
+        for k, v in device_map.items():
             for layer in v:
-                cuda_device = "cuda:" + str(k)
-                self.block[layer] = self.block[layer].to(cuda_device)
-
-        # Set embed_tokens to first layer
-        self.embed_tokens = self.embed_tokens.to(self.first_device)
-        # Set final layer norm to last device
-        self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+                self.block[layer].block_parallelize(f"cuda:{k}")
+        self.embed_tokens.to(self.first_device)
+        self.final_layer_norm.to(self.last_device)
+        self.model_parallel = True

-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
-        self.model_parallel = False
-        self.device_map = None
+        self.to("cpu")
         self.first_device = "cpu"
         self.last_device = "cpu"
         for i in range(len(self.block)):
-            self.block[i] = self.block[i].to("cpu")
-        self.embed_tokens = self.embed_tokens.to("cpu")
-        self.final_layer_norm = self.final_layer_norm.to("cpu")
+            self.block[i].block_deparallelize()
+        self.embed_tokens.to("cpu")
+        self.final_layer_norm.to("cpu")
+        self.model_parallel = False
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
```
```diff
@@ -833,10 +841,6 @@ def forward(
         output_hidden_states=None,
         return_dict=None,
     ):
-        # Model parallel
-        if self.model_parallel:
-            torch.cuda.set_device(self.first_device)
-            self.embed_tokens = self.embed_tokens.to(self.first_device)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
```
```diff
@@ -904,20 +908,6 @@ def forward(
         hidden_states = self.dropout(inputs_embeds)

         for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if position_bias is not None:
-                    position_bias = position_bias.to(hidden_states.device)
-                if encoder_hidden_states is not None:
-                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
-                if encoder_extended_attention_mask is not None:
-                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
-                if encoder_decoder_position_bias is not None:
-                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
```
```diff
@@ -952,12 +942,6 @@ def forward(
             if self.is_decoder:
                 all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
```
```diff
@@ -1161,30 +1145,22 @@ def __init__(self, config: T5Config):
         self.init_weights()

         # Model parallel
         self.model_parallel = False
-        self.device_map = None
+        self.main_device = None

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.encoder.block))
-        self.encoder.parallelize(self.device_map)
-        self.decoder.parallelize(self.device_map)
+        device_map = init_device_map(len(self.encoder.block), device_map)
+        self.encoder.parallelize(device_map)
+        self.decoder.parallelize(device_map)
+        self.main_device = self.encoder.first_device
         self.model_parallel = True

     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
         self.encoder.deparallelize()
         self.decoder.deparallelize()
-        self.encoder = self.encoder.to("cpu")
-        self.decoder = self.decoder.to("cpu")
         self.model_parallel = False
-        self.device_map = None
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
```
```diff
@@ -1265,18 +1241,6 @@ def forward(
         )

         hidden_states = encoder_outputs[0]
-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-            hidden_states = hidden_states.to(self.decoder.first_device)
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
-            if attention_mask is not None:
-                attention_mask = attention_mask.to(self.decoder.first_device)
-            if decoder_attention_mask is not None:
-                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

         # Decode
         decoder_outputs = self.decoder(
```
```diff
@@ -1296,6 +1260,11 @@ def forward(
         if not return_dict:
             return decoder_outputs + encoder_outputs

+        if self.model_parallel:
+            encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device(
+                self.main_device, encoder_outputs, decoder_outputs
+            )
+
         return Seq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             past_key_values=decoder_outputs.past_key_values,
```
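`model_parallel_inputs_to_specific_device` is also defined outside this diff, in `model_parallel_utils`. A plausible sketch, assuming it recursively moves any tensors inside tuples and lists onto the given device and passes everything else (e.g. `None`, output dataclasses' fields) through unchanged:

```python
import torch


def model_parallel_inputs_to_specific_device(device, *inputs):
    # Sketch (assumption, not the PR's code): gather outputs back onto one
    # device, e.g. self.main_device, before building the return object.
    def relocate(obj):
        if torch.is_tensor(obj):
            return obj.to(device)
        if isinstance(obj, (tuple, list)):
            return type(obj)(relocate(o) for o in obj)
        return obj

    out = tuple(relocate(i) for i in inputs)
    return out if len(out) > 1 else out[0]


a = torch.zeros(2)
b = (torch.ones(2), None)  # nested structure with a non-tensor element
a2, b2 = model_parallel_inputs_to_specific_device("cpu", a, b)
print(a2.device)  # cpu
```

Under model parallelism the decoder's outputs live on the last device, so relocating them to `main_device` keeps the returned `Seq2SeqModelOutput` internally consistent for the caller.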
|
```diff
@@ -1341,32 +1310,26 @@ def __init__(self, config):
         self.init_weights()

         # Model parallel
         self.model_parallel = False
-        self.device_map = None
+        self.first_device = None
+        self.last_device = None

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.encoder.block))
-        self.encoder.parallelize(self.device_map)
-        self.decoder.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.decoder.first_device)
+        device_map = init_device_map(len(self.encoder.block), device_map)
+        self.encoder.parallelize(device_map)
+        self.decoder.parallelize(device_map)
+        self.first_device = self.encoder.first_device
+        self.last_device = self.decoder.last_device
+        self.lm_head.to(self.first_device)
         self.model_parallel = True

     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
         self.encoder.deparallelize()
         self.decoder.deparallelize()
-        self.encoder = self.encoder.to("cpu")
-        self.decoder = self.decoder.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
+        self.lm_head.to("cpu")
         self.model_parallel = False
-        self.device_map = None
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
```
```diff
@@ -1456,9 +1419,6 @@ def forward(
         hidden_states = encoder_outputs[0]

-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-
         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
             # get decoder inputs from shifting lm labels to the right
             decoder_input_ids = self._shift_right(labels)
```
```diff
@@ -1472,17 +1432,6 @@ def forward(
         if decoder_inputs_embeds is not None:
             decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

-        # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.decoder.first_device)
-            hidden_states = hidden_states.to(self.decoder.first_device)
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
-            if attention_mask is not None:
-                attention_mask = attention_mask.to(self.decoder.first_device)
-            if decoder_attention_mask is not None:
-                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
```
```diff
@@ -1498,13 +1447,13 @@ def forward(
             return_dict=return_dict,
         )

-        sequence_output = decoder_outputs[0]
-
-        # Set device for model parallelism
         if self.model_parallel:
-            torch.cuda.set_device(self.encoder.first_device)
-            self.lm_head = self.lm_head.to(self.encoder.first_device)
-            sequence_output = sequence_output.to(self.lm_head.weight.device)
+            encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device(
+                self.first_device, encoder_outputs, decoder_outputs
+            )
+            # self.lm_head.to(self.first_device)
+
+        sequence_output = decoder_outputs[0]

         if self.config.tie_word_embeddings:
             # Rescale output before projecting on vocab
```
```diff
@@ -1596,23 +1545,18 @@ def __init__(self, config: T5Config):
         self.init_weights()

         self.model_parallel = False

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
-        self.device_map = (
-            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.encoder.block))
-        self.encoder.parallelize(self.device_map)
+        device_map = init_device_map(len(self.encoder.block), device_map)
+        self.encoder.parallelize(device_map)
         self.model_parallel = True

     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
     def deparallelize(self):
         self.encoder.deparallelize()
-        self.encoder = self.encoder.to("cpu")
         self.model_parallel = False
-        self.device_map = None
         torch.cuda.empty_cache()

     def get_input_embeddings(self):
```
Reviewer: Do we need `model_parallel` as an attribute on the module itself?

Author: Yes, if we use the automatic input remapper, since there is no other way to do `if self.model_parallel`. But we will surely remove anything that ends up unused.

Reviewer: Actually, do we really need `block_deparallelize`? What would be the practical use?