21 changes: 10 additions & 11 deletions src/transformers/models/t5/modeling_t5.py
@@ -785,17 +785,15 @@ def parallelize(self, device_map=None):
            )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
-        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
+        self.first_device = "cpu" if "cpu" in self.device_map.keys() else f"cuda:{ list(self.device_map.keys())[0] }"
+        self.last_device = f"cuda:{ list(self.device_map.keys())[-1] }"
Comment on lines +788 to +789
@stas00 (Dec 27, 2020):

So this change deals with the case where devices aren't necessarily ordered by ID, i.e. where a device_map's keys are, say, [0, 2, 1, 3] for whatever reason.

The only issue is that dicts are guaranteed to preserve insertion order only from Python 3.7 on (CPython 3.6 already does so, but only as an implementation detail), so on py36 the 0th item might not actually be the 0th item of the device map. Not sure how to deal with that.

Alternatively we could assert if the order is not numerically rising, but again the dict-ordering problem bites before py37.

Perhaps the simplest thing is to say that MP requires Python >= 3.7.

Or the user needs to build the device_map using collections.OrderedDict.

We could check in the device_map validation function that the passed device_map is either a collections.OrderedDict or that we're on Python >= 3.7, and assert otherwise.
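A minimal sketch of what such a validation could look like (the function name `validate_device_map_order` is hypothetical, not part of the PR; it combines both suggestions above: the OrderedDict/py37 check and the numerically-rising check):

```python
import sys
from collections import OrderedDict

def validate_device_map_order(device_map):
    """Hypothetical check: reject device maps whose key order can't be trusted.

    On CPython < 3.7 plain dicts don't guarantee insertion order, so require
    an OrderedDict there; additionally require device ids to be numerically
    rising so that first/last-device lookups stay unambiguous.
    """
    if sys.version_info < (3, 7) and not isinstance(device_map, OrderedDict):
        raise ValueError(
            "On Python < 3.7, pass a collections.OrderedDict as device_map, "
            "since plain dict key order is not guaranteed."
        )
    keys = [k for k in device_map.keys() if k != "cpu"]
    if keys != sorted(keys):
        raise ValueError(f"device_map keys must be numerically rising, got {keys}")
```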

        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)
@stas00:

These are loud and clear in the code already.


    @add_start_docstrings(PARALLELIZE_DOCSTRING)

@@ -833,7 +831,6 @@ def forward(
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
-            self.embed_tokens = self.embed_tokens.to(self.first_device)
@stas00:

This has already been done in parallelize. Did I miss something, i.e. could embed_tokens have changed since parallelize was called, or would it be better to do this here rather than in parallelize?

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (

@@ -904,9 +901,7 @@ def forward(
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
Comment on lines -907 to -909
@stas00:

This variable doesn't need to be moved to another device.

                # Ensure that the layer_module args are on the same device as hidden_states
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:

@@ -915,6 +910,9 @@
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                if all_hidden_states is not None:
+                    all_hidden_states = all_hidden_states.to(hidden_states.device)

Comment on lines +913 to +915
@stas00 (Dec 27, 2020):

This, on the other hand, does need to be moved to hidden_states.device. Otherwise we end up with the next line summing two variables that could be on two different devices.

Moreover, this tells me a test is missing for this case.

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

@@ -951,9 +949,10 @@ def forward(

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
Comment on lines -954 to -956
@stas00 (Dec 27, 2020):

And here there is the additional assumption that devices are not only ordered but also have a stride of 1, so the above would fail for [0, 2, 1, 3]: when k == 0, k + 1 gives device 1, but it should give device 2.

@stas00 (Dec 27, 2020):

The code is very hard to parse, which is why I think all of this should be abstracted away, e.g. we should be able to write:

if self.model_parallel:
    current_device = self.device_map.layer_to_device(i)
    if self.device_map.is_this_layer_last_for_this_device(i) and self.device_map.device_not_last(current_device):
        hidden_states = hidden_states.to(self.device_map.next_device(current_device))

Something like that; this whole thing can probably be abstracted into a single function anyway.

Or, since .to() is close to free when the tensor is already on the target device, the logic can be just:

if self.model_parallel:
    hidden_states = hidden_states.to(self.device_map.next_device(i))

Here next_device simply returns the same device while we aren't done with the current device, and the next device otherwise.
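A sketch of what such an abstraction could look like (the class and method names below are hypothetical, merely illustrating the interface imagined above; the map format is {device_id: [layer indices]} as used by parallelize):

```python
class DeviceMapView:
    """Hypothetical wrapper: hides the layer -> device bookkeeping
    behind named methods, so the forward loop reads declaratively."""

    def __init__(self, device_map):
        # e.g. {0: [0, 1], 2: [2, 3], 1: [4, 5]} - insertion order defines
        # the device sequence, which need not be numerically rising
        self._map = dict(device_map)
        self._devices = list(self._map.keys())

    def layer_to_device(self, layer):
        for d, layers in self._map.items():
            if layer in layers:
                return d
        raise KeyError(layer)

    def next_device(self, layer):
        """Device the next layer lives on; same device if this layer isn't
        the last one on its device, so callers can .to() unconditionally."""
        d = self.layer_to_device(layer)
        i = self._devices.index(d)
        if layer == self._map[d][-1] and i + 1 < len(self._devices):
            return self._devices[i + 1]
        return d
```

With this, the forward loop collapses to `hidden_states = hidden_states.to(dm.next_device(i))` (modulo the cuda: prefix question discussed below).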

@stas00 (Dec 27, 2020):

This works:

def next_device(device_map, current_layer_number):
    devices = list(device_map.keys())
    for i, d in enumerate(devices):
        if current_layer_number == device_map[d][-1]:
            if i + 1 != len(devices):
                return devices[i + 1]
            else:
                return devices[i]

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
    hidden_states = hidden_states.to(next_device(self.device_map, i))
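As a standalone sanity check of this helper on the non-contiguous [0, 2, 1, 3] map from earlier (the function is reproduced here so the sketch runs on its own; note that for non-boundary layers it falls through and returns None, which the snippet above presumably relies on `.to()` treating as "stay put"):

```python
def next_device(device_map, current_layer_number):
    devices = list(device_map.keys())
    for i, d in enumerate(devices):
        if current_layer_number == device_map[d][-1]:
            if i + 1 != len(devices):
                return devices[i + 1]
            else:
                return devices[i]
    # implicit None for non-boundary layers

device_map = {0: [0, 1], 2: [2, 3], 1: [4, 5], 3: [6, 7]}
assert next_device(device_map, 1) == 2   # end of device 0 -> device 2, not 1
assert next_device(device_map, 3) == 1   # end of device 2 -> device 1
assert next_device(device_map, 7) == 3   # last device overall: stays put
assert next_device(device_map, 2) is None  # mid-device: no switch
```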

@stas00 (Dec 27, 2020):

This is perhaps more readable/efficient:

def next_device_if_any(device_map, current_layer_number):
    """Returns False if we shouldn't switch, otherwise the device id to switch to."""
    devices = list(device_map.keys())
    for i, d in enumerate(devices):
        if current_layer_number == device_map[d][-1] and i + 1 != len(devices):
            return devices[i + 1]  # next device
            # return f"cuda:{devices[i+1]}"  # works with or without the cuda: prefix
    return False

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
    next_device_id = next_device_if_any(self.device_map, i)
    if next_device_id is not False:
        hidden_states = hidden_states.to(next_device_id)

This is an example of what could become a method of the self.device_map object, which would make the code much easier to read.
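The same kind of standalone check for this variant (function reproduced so the sketch is self-contained; here a non-boundary layer explicitly yields False rather than None):

```python
def next_device_if_any(device_map, current_layer_number):
    """Returns False if we shouldn't switch, otherwise the device id to switch to."""
    devices = list(device_map.keys())
    for i, d in enumerate(devices):
        if current_layer_number == device_map[d][-1] and i + 1 != len(devices):
            return devices[i + 1]
    return False

device_map = {0: [0, 1], 2: [2, 3], 1: [4, 5], 3: [6, 7]}
assert next_device_if_any(device_map, 1) == 2       # boundary -> next device
assert next_device_if_any(device_map, 2) is False   # mid-device: no switch
assert next_device_if_any(device_map, 7) is False   # last device: no switch
```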

@stas00 (Dec 27, 2020):

The other, even simpler, approach would be to create two reverse-lookup dicts when model.parallelize is called:

  1. layer number -> device_id
  2. layer number -> next_device_id

Then it's just one trivial lookup at the point where it's needed, with the special case that the last group of layers maps to None for next_device_id, so the code becomes:

if self.model_parallel:
    if self.device_map_layer_to_next_device_id[i] is not None:
        hidden_states = hidden_states.to(self.device_map_layer_to_next_device_id[i])

And no extra code is needed at all.
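A sketch of that one-time precomputation (the helper name `build_reverse_maps` is hypothetical, not code from the PR):

```python
def build_reverse_maps(device_map):
    """Build the two reverse-lookup dicts proposed above, once, in
    parallelize(): per-layer device id, and per-layer next device id
    (None for layers that don't trigger a switch)."""
    layer_to_device = {}
    layer_to_next_device = {}
    devices = list(device_map.keys())
    for i, d in enumerate(devices):
        for layer in device_map[d]:
            layer_to_device[layer] = d
            # only the last layer of a non-final device points at a next device
            is_boundary = layer == device_map[d][-1] and i + 1 < len(devices)
            layer_to_next_device[layer] = devices[i + 1] if is_boundary else None
    return layer_to_device, layer_to_next_device

l2d, l2n = build_reverse_maps({0: [0, 1], 2: [2, 3], 1: [4, 5]})
assert l2d == {0: 0, 1: 0, 2: 2, 3: 2, 4: 1, 5: 1}
assert l2n == {0: None, 1: 2, 2: None, 3: 1, 4: None, 5: None}
```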

@stas00 (Dec 27, 2020):

And if we stick to cuda:X device names, we should do all that remapping at the beginning too, rather than having the mess of gluing the cuda: prefix on everywhere. It also makes comparisons against foo.device easier.
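A sketch of what that up-front remapping could look like (the helper name `normalize_device_map` is hypothetical):

```python
def normalize_device_map(device_map):
    """Convert integer device ids to 'cuda:X' strings once, at
    parallelize() time, so downstream code never builds the prefix
    itself and device comparisons are direct string matches."""
    return {
        k if isinstance(k, str) else f"cuda:{k}": v
        for k, v in device_map.items()
    }

assert normalize_device_map({0: [0, 1], 1: [2, 3]}) == {"cuda:0": [0, 1], "cuda:1": [2, 3]}
assert normalize_device_map({"cpu": [0]}) == {"cpu": [0]}  # string keys pass through
```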

@stas00:

Well, if #9323 is welcomed, this whole discussion, while still valid, becomes moot, since none of this code is needed at all.

+                devices = list(self.device_map.keys())
+                for e, d in enumerate(devices):
+                    if i == self.device_map[d][-1] and f"cuda:{d}" != self.last_device:
+                        hidden_states = hidden_states.to(f"cuda:{devices[e+1]}")

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)