[t5 model parallel] misc fixes #9316
Changes from all commits
```diff
@@ -785,17 +785,15 @@ def parallelize(self, device_map=None):
         )
         assert_device_map(self.device_map, len(self.block))
         self.model_parallel = True
-        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
+        self.first_device = "cpu" if "cpu" in self.device_map.keys() else f"cuda:{ list(self.device_map.keys())[0] }"
+        self.last_device = f"cuda:{ list(self.device_map.keys())[-1] }"
         # Load onto devices
         for k, v in self.device_map.items():
             for layer in v:
                 cuda_device = "cuda:" + str(k)
                 self.block[layer] = self.block[layer].to(cuda_device)

-        # Set embed_tokens to first layer
         self.embed_tokens = self.embed_tokens.to(self.first_device)
-        # Set final layer norm to last device
         self.final_layer_norm = self.final_layer_norm.to(self.last_device)
```
Contributor (author): these are loud and clear in the code already.
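As an aside on the `first_device`/`last_device` change above: the old code derived them from `min()`/`max()` of the keys, while the new code takes the first and last keys in the dict's own order. A minimal sketch (assuming a plain-dict `device_map` keyed by device id, values being layer indices) of how the two can disagree when the keys are not inserted in ascending order:

```python
# Hypothetical device_map, deliberately inserted out of numeric order.
device_map = {2: [0, 1], 0: [2, 3], 1: [4, 5]}

old_first = "cuda:" + str(min(device_map.keys()))  # numeric minimum
old_last = "cuda:" + str(max(device_map.keys()))   # numeric maximum

new_first = f"cuda:{list(device_map.keys())[0]}"   # first inserted key
new_last = f"cuda:{list(device_map.keys())[-1]}"   # last inserted key

print(old_first, old_last)  # cuda:0 cuda:2
print(new_first, new_last)  # cuda:2 cuda:1
```

Which of the two behaviors is the right one is exactly what the dict-ordering discussion at the end of this review is about.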
```diff
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
```
```diff
@@ -833,7 +831,6 @@ def forward(
         # Model parallel
         if self.model_parallel:
             torch.cuda.set_device(self.first_device)
-            self.embed_tokens = self.embed_tokens.to(self.first_device)
```
Contributor (author): this has already been called in …
```diff
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
```
```diff
@@ -904,9 +901,7 @@ def forward(
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
```
Contributor (author), on lines -907 to -909: this variable does not need to be moved to other devices.
```diff
                 # Ensure that the layer_module args are on the same device as hidden_states
                 if position_bias is not None:
                     position_bias = position_bias.to(hidden_states.device)
                 if encoder_hidden_states is not None:
```
```diff
@@ -915,6 +910,9 @@ def forward(
                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                 if encoder_decoder_position_bias is not None:
                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                if all_hidden_states is not None:
+                    all_hidden_states = all_hidden_states.to(hidden_states.device)
```
Contributor (author), on lines +913 to +915: this one, on the other hand, does need to be switched to the same device as `hidden_states`. Moreover, it's telling me there is a test missing for this case.
```diff
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
```
```diff
@@ -951,9 +949,10 @@ def forward(
             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
```
Contributor (author), on lines -954 to -956: and here we have an issue with the assumption that the devices are not only ordered but also have a stride of 1, so the above would fail if we have a device_map whose keys are not contiguous.
Contributor (author): the code is very hard to parse, which is why I think all of this should be abstracted away; e.g. we should be able to say something like …, and this whole thing can probably be abstracted into a single function anyway. Or perhaps, since …
Contributor (author): This works: …
Contributor (author): This perhaps is more readable/efficient: …; this was an example of what could be a method of …
Contributor (author): The other, even simpler approach would be to create 2 reversed dicts, layer -> device_id and layer -> next_device_id, when you call …, and then it's just one trivial lookup at the point where it's needed, with the special case of the last group of layers pointing to None for next_device_id. So then the code becomes …, and no extra code is needed at all.
Contributor (author): and if we stick to …
Contributor (author): Well, if #9323 is welcomed, this whole discussion, while still valid, becomes moot, since none of this code is needed at all.
```diff
+                devices = list(self.device_map.keys())
+                for e, d in enumerate(devices):
+                    if i == self.device_map[d][-1] and f"cuda:{d}" != self.last_device:
+                        hidden_states = hidden_states.to(f"cuda:{devices[e+1]}")
```
```diff
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
```
So this change deals with the case where the devices aren't necessarily ordered by ID, i.e. where a device_map's keys are, say, [0, 2, 1, 3] for whatever reason.

The only issue here is that dicts are guaranteed to preserve insertion order only starting from py37, so for py36 (unless it's a CPython build, where ordered dicts are an implementation detail) we have an issue that the 0th item might not be the 0th item in the device map. Not sure how to deal with it.

Or alternatively we need to assert if the order is not numerically rising, but again we have an issue with the dict before py37.

Perhaps the simplest thing is to say that MP requires py > 3.6. Or the user needs to build the device_map using collections.OrderedDict. We could check in the device_map validation function whether the passed device_map is either a collections.OrderedDict or py > 3.6, and assert otherwise.
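A minimal sketch of such a validation check (function name hypothetical; relies on the py37 insertion-order guarantee):

```python
import sys
from collections import OrderedDict

def check_device_map_ordering(device_map):
    """Hypothetical validation: on py >= 3.7 plain dicts preserve insertion
    order by language guarantee, so any dict is fine; on older interpreters
    require an explicit collections.OrderedDict."""
    if sys.version_info >= (3, 7):
        return
    if not isinstance(device_map, OrderedDict):
        raise TypeError(
            "On Python < 3.7, device_map must be a collections.OrderedDict "
            "so the device order is well defined."
        )

check_device_map_ordering({0: [0, 1], 1: [2, 3]})  # passes on modern Pythons
```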