Conversation
    self.first_device = "cpu" if "cpu" in self.device_map.keys() else f"cuda:{ list(self.device_map.keys())[0] }"
    self.last_device = f"cuda:{ list(self.device_map.keys())[-1] }"
So this change deals with the case where devices aren't necessarily ordered by ID - that is, the case where a device_map's keys are, say, [0, 2, 1, 3] for whatever reason.
The only issue here is that we only get guaranteed-ordered dicts starting from py37 and then all is good; for py36, unless it is CPython (where insertion order is an implementation detail), we have an issue that the 0th item might not be the 0th item in the device map. Not sure how to deal with it.
Alternatively, we could assert if the order is not numerically rising, but again we have the dict-ordering issue before py37.
Perhaps the simplest thing is to say that MP requires py > 3.6.
Or the user needs to build the device_map using collections.OrderedDict.
We could check in the device_map validation function whether the passed device_map is a collections.OrderedDict or whether we are on py > 3.6, and assert otherwise.
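A minimal sketch of what that check could look like (the helper name assert_device_map_ordered is made up here; the device_map format {device_id: [layer indices]} is the one used in this PR):

    import sys
    from collections import OrderedDict

    def assert_device_map_ordered(device_map):
        # plain dicts are only guaranteed to preserve insertion order from py3.7 on
        # (CPython 3.6 does it too, but only as an implementation detail)
        if sys.version_info < (3, 7) and not isinstance(device_map, OrderedDict):
            raise ValueError(
                "On python < 3.7 device_map must be a collections.OrderedDict, "
                "otherwise the device order cannot be relied upon"
            )

parallelize() could call such a check right after receiving the user's device_map.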
    # Model parallel
    if self.model_parallel:
        torch.cuda.set_device(self.first_device)
        self.embed_tokens = self.embed_tokens.to(self.first_device)
this has already been called in parallelize - did I miss something and it may have changed since parallelize was called, or would it be better to call it here rather than in parallelize?
    for k, v in self.device_map.items():
        if i == v[-1] and "cuda:" + str(k) != self.last_device:
            hidden_states = hidden_states.to("cuda:" + str(k + 1))
and here we have the assumption that the devices are not only ordered but also have a stride of 1, so the above would fail if we have [0, 2, 1, 3]: when k == 0, k + 1 gives 1, but the next device should be 2.
the code is very hard to parse; that's why I think all of this should be abstracted away, e.g. we should be able to say:

    if self.model_parallel:
        current_device = self.device_map.layer_to_device(i)
        if self.device_map.is_this_layer_last_for_this_device(i) and self.device_map.device_not_last(current_device):
            hidden_states = hidden_states.to(self.device_map.next_device(current_device))

something like that - and this whole thing can probably be abstracted into a single function anyway.
Or perhaps, since to() is close to free when it's already the same device, the logic can be just:

    if self.model_parallel:
        hidden_states = hidden_states.to(self.device_map.next_device(i))

so next_device simply returns the same device if we aren't done with this device, otherwise it returns the next device.
This works:

    def next_device(device_map, current_layer_number):
        devices = list(device_map.keys())
        for i, d in enumerate(devices):
            if current_layer_number == device_map[d][-1]:
                if i+1 != len(devices):
                    return devices[i+1]
                else:
                    return devices[i]

    # Model Parallel: If it's the last layer for that device, put things on the next device
    if self.model_parallel:
        hidden_states = hidden_states.to(next_device(self.device_map, i))
This perhaps is more readable/efficient:

    def next_device_if_any(device_map, current_layer_number):
        """ returns False if shouldn't switch, otherwise device id to switch to """
        devices = list(device_map.keys())
        for i, d in enumerate(devices):
            if current_layer_number == device_map[d][-1] and i+1 != len(devices):
                return devices[i+1]  # next device
                # return f"cuda:{devices[i+1]}"  # works w/ or w/o cuda:
        return False

    # Model Parallel: If it's the last layer for that device, put things on the next device
    if self.model_parallel:
        next_device_id = next_device_if_any(self.device_map, i)
        if next_device_id is not False:
            hidden_states = hidden_states.to(next_device_id)
this was an example of what could be a method of the self.device_map object, which would make the code much easier to read.
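For illustration, a rough sketch of what such a method could look like - the DeviceMap class name is made up here, and it just wraps the {device_id: [layer indices]} dict used in this PR:

    class DeviceMap(dict):
        """hypothetical wrapper around the {device_id: [layer indices]} mapping"""

        def next_device_if_any(self, current_layer_number):
            """returns False if we shouldn't switch, otherwise the device id to switch to"""
            devices = list(self.keys())
            for i, d in enumerate(devices):
                if current_layer_number == self[d][-1] and i + 1 != len(devices):
                    return devices[i + 1]
            return False

With something like that, the per-model forward loop only needs self.device_map.next_device_if_any(i).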
The other even simpler approach would be to create 2 reversed dicts for:
- layer number -> device_id
- layer number -> next_device_id
when you call model.parallelize
and then it's just one trivial lookup at the point where it's needed, with the special case of the last group of layers pointing to None for next_device_id, so the code becomes:

    if self.model_parallel:
        if self.device_map_layer_to_next_device_id[i] is not None:
            hidden_states = hidden_states.to(self.device_map_layer_to_next_device_id[i])
and no extra code needed at all.
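A minimal sketch of how those two lookups could be built at parallelize time (the helper name build_layer_lookups and the returned dict names are made up for illustration):

    def build_layer_lookups(device_map):
        """hypothetical helper: returns (layer -> device_id, layer -> next_device_id or None)"""
        layer_to_device_id = {}
        layer_to_next_device_id = {}
        devices = list(device_map.keys())
        for idx, device in enumerate(devices):
            layers = device_map[device]
            for layer in layers:
                layer_to_device_id[layer] = device
                # only the last layer of a device that isn't the last device triggers a switch
                if layer == layers[-1] and idx + 1 < len(devices):
                    layer_to_next_device_id[layer] = devices[idx + 1]
                else:
                    layer_to_next_device_id[layer] = None
        return layer_to_device_id, layer_to_next_device_id

parallelize() would store the two dicts as attributes and the forward pass is then just the two-line lookup shown above.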
and if we stick to cuda:X, we should do all that remapping at the beginning too and not have the mess of adding cuda: everywhere. It's also easier to compare foo.device that way.
Well, if #9323 is welcomed, this whole discussion, while still valid, becomes moot, since none of this code is needed at all.
    # Set embed_tokens to first layer
    self.embed_tokens = self.embed_tokens.to(self.first_device)
    # Set final layer norm to last device
    self.final_layer_norm = self.final_layer_norm.to(self.last_device)
these comments just restate what is already loud and clear in the code itself.
    # Ensure that attention_mask is always on the same device as hidden_states
    if attention_mask is not None:
        attention_mask = attention_mask.to(hidden_states.device)
this variable doesn't need to be switched to other devices.
    if all_hidden_states is not None:
        all_hidden_states = all_hidden_states.to(hidden_states.device)
This, on the other hand, does need to be switched to hidden_states.device. Otherwise we end up on the next line with 2 vars being summed that could be on 2 different devices.
Moreover, it tells me there is a test missing for this case.
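A rough sketch of what such a test could look like, assuming at least 2 visible GPUs and the parallelize() API used in this PR (the checkpoint and device_map below are just examples):

    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    def test_model_parallel_output_hidden_states():
        if torch.cuda.device_count() < 2:
            return  # this test needs at least 2 GPUs
        model = T5ForConditionalGeneration.from_pretrained("t5-small")
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        # t5-small has 6 blocks per stack; split them across 2 devices
        model.parallelize({0: [0, 1, 2], 1: [3, 4, 5]})
        input_ids = tokenizer("translate English to German: hello", return_tensors="pt").input_ids.to("cuda:0")
        # output_hidden_states=True exercises accumulating hidden states across devices
        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, output_hidden_states=True)
        assert outputs.encoder_hidden_states is not None
        assert outputs.decoder_hidden_states is not None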
Somehow I feel that this approach of having special logic to remap inputs ahead of time is over-complicated. I haven't tried it yet, but won't it be much, much simpler to remap inputs once the model layer is visible, just before they are used by that layer - i.e. at the point in the loop where the layer and its device are at hand? Then we just do the move there, and we might even be able to remove all those explicit remappings.
edit: I branched off from this PR and implemented this - it works amazingly simply: #9323
I also think #9323 is the way to go here.
too long. closing. |
This PR:
- removes the assumption that the devices in the device_map are ordered like (0, 1, 2, 3) and handles maps like: (2, 3, 0, 1), (0, 2, 3, 5)
- adds a missing to(), removes a redundant to()

I will comment on the reasons for changes in the code.
There is one gotcha wrt py36 without CPython not having its dict ordered. Please see #9316 (comment)
I think the first_device / last_device / is_this_the_last_layer_of_this_device logic should be abstracted away for readability, so that the same logic doesn't need to be replicated in each model. Perhaps self.device_map should be a smart class that can provide all the answers via its methods.

@alexorona, I'm studying your t5-mp implementation to do the same for bart. Thank you for doing the hard work of putting the foundation in place and porting 2 models!!!
Please have a look and let me know if my tweaks make sense. Your original code is excellent - I'm just trying to think about how to make it easier to replicate in other models and how to improve readability, hence the gazillion of questions/suggestions.
Also, if you don't mind, I have a few design questions:
Could you please comment on why you are splitting the encoder between all devices in the device map, and the same for the decoder? Won't it be more efficient performance-wise to put the encoder on the first group of devices and the decoder on the second?
I also find it confusing that the device map doesn't map out the whole model, but just the encoder, and assumes that the decoder has the same config. I'm not familiar with t5, but other models can definitely have an encoder and decoder that don't match at all in the number of layers. And while this is perhaps not the case for t5, I think the device map should be intuitively similar for all models as we slowly progress with porting other models to MP. That is, I think it should include all the layers of the model and not half of them.
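To illustrate: today the map only lists the encoder's block indices and the decoder implicitly gets the identical split; a whole-model map could spell both out (the nested format below is purely hypothetical):

    # current style for a model with 12 blocks on 4 GPUs: encoder block indices only,
    # the decoder is implicitly split the same way
    device_map = {0: [0, 1, 2], 1: [3, 4, 5], 2: [6, 7, 8], 3: [9, 10, 11]}

    # hypothetical whole-model style: every layer mapped explicitly, which also covers
    # models whose encoder and decoder differ in depth, and would allow e.g. the encoder
    # on the first group of devices and the decoder on the second
    device_map = {
        0: {"encoder": [0, 1, 2, 3, 4, 5]},
        1: {"encoder": [6, 7, 8, 9, 10, 11]},
        2: {"decoder": [0, 1, 2, 3, 4, 5]},
        3: {"decoder": [6, 7, 8, 9, 10, 11]},
    }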
@patrickvonplaten, @LysandreJik