
[t5 model parallel] misc fixes#9316

Closed
stas00 wants to merge 3 commits intohuggingface:masterfrom
stas00:t5-mp

Conversation

Contributor

@stas00 stas00 commented Dec 27, 2020

This PR:

  • fixes, in 2 places, the assumption that devices in the device map are always (0, 1, 2, 3), i.e. that they:
    1. are ordered by their cuda device id, and not e.g. (2, 3, 0, 1)
    2. have a stride of 1, and not e.g. (0, 2, 3, 5)
  • adds a missing to(), removes a redundant to()
  • removes obvious comments
  • removes code that gets run twice
  • this PR continues at [T5 model parallel] implement input auto-relocation + lots of refactoring/cleanup #9323 - I branched off from this PR and implemented an automatic remap of inputs and a lot of refactoring.

I will comment on the reasons for changes in the code.

There is one gotcha wrt py36: dict ordering isn't guaranteed there (it's only an implementation detail of CPython 3.6, and guaranteed by the language only from py37). Please see #9316 (comment)

I think the first-device/last-device/is-this-the-last-layer-of-this-device logic should be abstracted away for readability, so that the same logic doesn't need to be replicated in each model. Perhaps self.device_map should be a smart class that can provide all the answers via its methods.

@alexorona, I'm studying your t5-mp implementation to do the same for bart. Thank you for doing the hard work of putting the foundation in place and porting 2 models!!!

Please have a look and let me know if my tweaks make sense. Your original code is excellent - I'm just trying to think how to make it easier to replicate in other models and improve readability, hence the gazillion questions/suggestions.

Also, if you don't mind I have a few design questions:

  1. Could you please comment on why you are splitting the encoder across all devices on the device map, and likewise the decoder? Wouldn't it be more efficient performance-wise to put the encoder on the first group of devices and the decoder on the second?

  2. I also find it confusing that the device map doesn't map out the whole model, but just the encoder, and assumes that the decoder has the same config. I'm not familiar with t5, but other models can definitely have an encoder and decoder that don't match in number of layers. And while this is perhaps not the case for t5, I think the device map should look intuitively similar across all models as we slowly progress with porting other models to MP. That is, I think it should include all layers of the model, not half of them.

@patrickvonplaten, @LysandreJik

Comment on lines +788 to +789
self.first_device = "cpu" if "cpu" in self.device_map.keys() else f"cuda:{ list(self.device_map.keys())[0] }"
self.last_device = f"cuda:{ list(self.device_map.keys())[-1] }"

@stas00 stas00 Dec 27, 2020


So this change deals with the case where devices aren't necessarily ordered by ID, i.e. where a device_map's keys are, say, [0,2,1,3] for whatever reason.

The only issue here is that dicts are guaranteed to be ordered only starting from py37 (for py36 it's merely a CPython implementation detail), so the 0th key we read might not be the 0th item the user put in the device map. Not sure how to deal with it.

Or alternatively we could assert if the order is not numerically rising, but again we have the dict-ordering issue before py37.

Perhaps the simplest thing is to say that MP requires py > 3.6

Or the user needs to build the device_map using collections.OrderedDict.

We could check in the device_map validation function whether the passed device_map is either collections.OrderedDict or py > 3.6 or assert otherwise.
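A minimal sketch of what that check could look like; the function name and error message are hypothetical, not the actual transformers API:

```python
import collections
import sys


def validate_device_map_ordering(device_map):
    """Hypothetical helper: reject device maps whose key order cannot be
    trusted. On py >= 3.7 dict insertion order is guaranteed by the
    language; before that, require an explicit collections.OrderedDict."""
    if sys.version_info >= (3, 7) or isinstance(device_map, collections.OrderedDict):
        return True
    raise ValueError(
        "device_map key order is not guaranteed on Python < 3.7; "
        "please pass a collections.OrderedDict"
    )
```

An OrderedDict always passes, so py36 users would still have an escape hatch.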

# Model parallel
if self.model_parallel:
torch.cuda.set_device(self.first_device)
self.embed_tokens = self.embed_tokens.to(self.first_device)
@stas00

this has already been called in parallelize. Did I miss something, i.e. could it have changed since parallelize was called, or would it be better to call it here rather than in parallelize?

Comment on lines -954 to -956
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))

@stas00 stas00 Dec 27, 2020


and here we have the assumption that devices are not only ordered but also have a stride of 1, so the above would fail if we have [0,2,1,3]: after device k=0's last layer, k+1 gives device 1, but the next device in the map is actually 2.
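To make the failure concrete, a small sketch with the out-of-order map discussed above (plain data, no CUDA needed):

```python
# Device map whose keys are intentionally not in numeric order:
# device 0 holds layers 0-1, device 2 holds layers 2-3, and so on.
device_map = {0: [0, 1], 2: [2, 3], 1: [4, 5], 3: [6, 7]}
keys = list(device_map.keys())

k = keys[0]            # device 0; its last layer is layer 1
old_target = k + 1     # the removed code would move hidden_states to cuda:1 ...
real_target = keys[1]  # ... but the next device in the map is cuda:2
```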


@stas00 stas00 Dec 27, 2020


the code is very hard to parse, which is why I think all of this should be abstracted away; e.g. we should be able to say:

if self.model_parallel: 
    current_device = self.device_map.layer_to_device(i)
    if self.device_map.is_this_layer_last_for_this_device(i) and self.device_map.device_not_last(current_device):
         hidden_states = hidden_states.to(self.device_map.next_device(current_device))

something like that - and this whole thing can probably be abstracted into a single function anyway.

Or, since to() is close to free when the tensor is already on the same device, the logic could be just:

if self.model_parallel:
    hidden_states = hidden_states.to(self.device_map.next_device(i))

so next_device simply returns the same device if we aren't done with this device, otherwise it'll return the next device.
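For illustration, a self-contained sketch of such a smart device-map object; the class and method names are hypothetical, loosely following the pseudocode above:

```python
class DeviceMap:
    """Hypothetical wrapper around a {device_id: [layer, ...]} dict that
    answers the placement questions so each model doesn't have to."""

    def __init__(self, device_map):
        self.map = dict(device_map)
        self.devices = list(self.map.keys())

    def layer_to_device(self, layer):
        """Device id that the given layer lives on."""
        for device, layers in self.map.items():
            if layer in layers:
                return device
        raise KeyError(f"layer {layer} is not in the device map")

    def next_device(self, layer):
        """Device the layer's outputs should live on: the same device,
        unless this is the last layer assigned to that device, in which
        case the next device in map order (the last device keeps its own
        outputs)."""
        device = self.layer_to_device(layer)
        i = self.devices.index(device)
        if layer == self.map[device][-1] and i + 1 < len(self.devices):
            return self.devices[i + 1]
        return device
```

With something like this in place, the forward-pass logic collapses to a single `hidden_states.to(...)` call on whatever `next_device(i)` returns.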


@stas00 stas00 Dec 27, 2020


This works:

            def next_device(device_map, current_layer_number):
                # find the device whose last layer this is and return the
                # following device in the map (or the same device if it's the
                # last one); returns None for non-boundary layers, and
                # Tensor.to(None) is a no-op
                devices = list(device_map.keys())
                for i, d in enumerate(devices):
                    if current_layer_number == device_map[d][-1]:
                        if i + 1 != len(devices):
                            return devices[i + 1]
                        else:
                            return devices[i]

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                hidden_states = hidden_states.to(next_device(self.device_map, i))


@stas00 stas00 Dec 27, 2020


This perhaps is more readable/efficient:

            def next_device_if_any(device_map, current_layer_number):
                """ returns False if shouldn't switch, otherwise device id to switch to """
                devices = list(device_map.keys())
                for i, d in enumerate(devices):
                    if current_layer_number == device_map[d][-1] and i+1 != len(devices):
                        return devices[i+1] # next device
                        #return f"cuda:{devices[i+1]}" # works with or without the "cuda:" prefix
                return False
                    
            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                next_device_id = next_device_if_any(self.device_map, i)
                if next_device_id is not False:
                    hidden_states = hidden_states.to(next_device_id)

this is an example of what could be a method of the self.device_map object, which would make the code much easier to read.


@stas00 stas00 Dec 27, 2020


The other even simpler approach would be to create 2 reversed dicts for:

  1. layer number -> device_id
  2. layer number -> next_device_id

when you call model.parallelize

and then it's just one trivial lookup at the point where it's needed, with the special case of the last group of layers mapping to None for next_device_id, so the code becomes:

if self.model_parallel:
    if self.device_map_layer_to_next_device_id[i] is not None:
        hidden_states = hidden_states.to(self.device_map_layer_to_next_device_id[i])

and no extra code needed at all.
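A sketch of building those two reversed lookups once at parallelize() time; the helper name is hypothetical:

```python
def build_layer_lookups(device_map):
    """Build the two reversed dicts described above from a
    {device_id: [layer, ...]} map: layer -> device_id, and
    layer -> device the layer's outputs should move to (None for the
    last device's layers, meaning "stay put")."""
    devices = list(device_map.keys())
    layer_to_device = {}
    layer_to_next_device = {}
    for i, device in enumerate(devices):
        for layer in device_map[device]:
            layer_to_device[layer] = device
            if i + 1 == len(devices):
                layer_to_next_device[layer] = None            # last device: no move
            elif layer == device_map[device][-1]:
                layer_to_next_device[layer] = devices[i + 1]  # boundary: next device
            else:
                layer_to_next_device[layer] = device          # stay on same device
    return layer_to_device, layer_to_next_device
```

Non-boundary layers map to their own device, so the unconditional to() stays a near-free no-op.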


@stas00 stas00 Dec 27, 2020


and if we stick to cuda:X device strings, we should do all that remapping at the beginning too, instead of the mess of prepending cuda: everywhere. It also makes comparing foo.device easier.

@stas00

Well, if #9323 is welcomed, this whole discussion, while still valid, becomes moot, since none of this code is needed at all.

@stas00 stas00 changed the title from "[t5 model parallel] adjustment for an assumption that devices are" to "[t5 model parallel] devices might not be sequential with stride 1" on Dec 27, 2020
# Set embed_tokens to first layer
self.embed_tokens = self.embed_tokens.to(self.first_device)
# Set final layer norm to last device
self.final_layer_norm = self.final_layer_norm.to(self.last_device)
@stas00

what these comments say is already loud and clear in the code itself.

Comment on lines -907 to -909
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
@stas00

this variable doesn't need to be switched to other devices

Comment on lines +913 to +915
if all_hidden_states is not None:
all_hidden_states = all_hidden_states.to(hidden_states.device)


@stas00 stas00 Dec 27, 2020


This one, on the other hand, does need to be moved to hidden_states.device. Otherwise the 2 vars summed on the next line could end up on 2 different devices.

Moreover, it tells me there is a test missing for this case.

@stas00 stas00 changed the title from "[t5 model parallel] devices might not be sequential with stride 1" to "[t5 model parallel] misc fixes" on Dec 27, 2020

stas00 commented Dec 27, 2020

Somehow I feel that this approach of having special logic to remap inputs ahead of time is over-complicated. I haven't tried it yet, but wouldn't it be much simpler to remap inputs once the model layer is visible, just before they are used by that layer, i.e. at the point where one gets:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Then we just do input_foo = input_foo.to(next(self.parameters()).device) and we are done? No logic required other than just put the inputs on the same device as this layer.

We might even be able to remove all those if self.model_parallel checks in most places in forward, and have the same code with or without MP, perhaps with some wrapper that is a no-op when not under MP. It could also handle None to avoid a gazillion if not None checks. I'd also make it an in-place operation, just like nn.Module.to does.
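A torch-free sketch of such a relocation helper (the name to_module_device is hypothetical, and this returns new references rather than working in place as suggested); it duck-types against anything exposing a parameters() iterator and tensors with a .to() method, so the same idea carries over unchanged to real nn.Modules:

```python
def to_module_device(module, *tensors):
    """Move every non-None tensor to the device of the module's own
    parameters, just before the module consumes them. Near-free when
    the devices already match; parameterless modules leave inputs
    untouched."""
    try:
        device = next(module.parameters()).device
    except StopIteration:
        return tensors  # no parameters, nothing to align with
    return tuple(t if t is None else t.to(device) for t in tensors)
```

Inside forward this would read something like `hidden_states, attention_mask = to_module_device(layer, hidden_states, attention_mask)`, with no MP-specific branching, since Tensor.to on an already-matching device returns the tensor itself.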

edit: I branched off from this PR and implemented this - it works amazingly simply: #9323

@patrickvonplaten

I also think #9323 is the way to go here

@stas00 stas00 added the Model Parallel Model Parallelilsm Implementations label Jan 2, 2021
@huggingface huggingface deleted a comment from github-actions bot Apr 15, 2021
@stas00 stas00 self-assigned this Apr 15, 2021
@stas00 stas00 added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Apr 15, 2021

stas00 commented Jun 4, 2021

too long. closing.

@stas00 stas00 closed this Jun 4, 2021