180 changes: 62 additions & 118 deletions src/transformers/models/t5/modeling_t5.py
@@ -40,7 +40,11 @@
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from ...utils.model_parallel_utils import (
init_device_map,
model_parallel_inputs_to_device,
model_parallel_inputs_to_specific_device,
)
from .configuration_t5 import T5Config


@@ -597,6 +601,17 @@ def __init__(self, config, has_relative_attention_bias=False):

self.layer.append(T5LayerFF(config))

self.model_parallel = False

def block_parallelize(self, device):
self.to(device)
self.model_parallel = True

def block_deparallelize(self):
self.to("cpu")
self.model_parallel = False
Contributor:

Do we need model_parallel as an attribute on the module itself?

Contributor Author:

Yes, if we use the automatic input remapper, since there is no other way to do `if self.model_parallel`; but we will surely remove anything that ends up unused.

Actually do we really need block_deparallelize? What would be the practical use?
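The automatic input remapper under discussion could be sketched roughly as follows. This is a hypothetical reconstruction, not the actual implementation in `model_parallel_utils`; it uses a duck-typed `.to()` check where the real decorator would likely use `torch.is_tensor`:

```python
import functools


def model_parallel_inputs_to_device(forward_fn):
    """Hypothetical sketch of the automatic input remapper: before forward()
    runs, move every tensor-like argument (anything exposing a .to() method)
    onto the device that this sub-module's own parameters live on."""

    @functools.wraps(forward_fn)
    def wrapper(self, *args, **kwargs):
        # Only remap when this block was placed by parallelize(); this is why
        # model_parallel needs to be an attribute on the module itself.
        if not getattr(self, "model_parallel", False):
            return forward_fn(self, *args, **kwargs)
        device = next(self.parameters()).device

        def move(x):
            return x.to(device) if hasattr(x, "to") else x

        args = tuple(move(a) for a in args)
        kwargs = {k: move(v) for k, v in kwargs.items()}
        return forward_fn(self, *args, **kwargs)

    return wrapper
```

With this in place, a block's `forward()` only needs the decorator line, and all its inputs land on the block's device automatically.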


@model_parallel_inputs_to_device
def forward(
self,
hidden_states,
@@ -776,41 +791,34 @@ def __init__(self, config, embed_tokens=None):
self.dropout = nn.Dropout(config.dropout_rate)

self.init_weights()
# Model parallel

Contributor:

Wouldn't something like:

if self.model_parallel:
        hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value = _call__mp(hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, head_mask, past_key_value)

directly under the function signature work?

But I agree that it's a lot of boilerplate code, so maybe the best way is indeed to use a decorator function here.
In general I don't think I'm deep enough into the model parallelism to have a good view here. Let's see what @LysandreJik thinks of the design :-).

Contributor Author (@stas00, Dec 29, 2020):

Oh, I see - you are suggesting to revert it to the original way where each input was manually switched to its destination device, albeit at the point of use.

There is a definite benefit to your suggestion in that it's explicit rather than a magical, behind-the-scenes switch. My initial reaction was "meh", but after sitting with it more, it fits well with all the other .to switches elsewhere in the MP code, since the rest aren't magical. Otherwise a code reader might be puzzled about why these inputs aren't being switched.

The more I think about it, the more I believe a future pytorch might find a way to avoid manually moving inputs to devices, and not just for MP. It should be able to figure it all out based on the params, which already know which devices they are on. And then it will all be magical.

> maybe the best way is indeed to use a decorator function here.

In case it wasn't clear, my follow-up suggested not a decorator (since you decided against it) but an even more efficient approach, which would impact only the MP code.

Let's see what @LysandreJik thinks.

Member:

Thanks for waiting for my input. I would prefer to go down the explicit road rather than the magical one.

Patrick's proposition seems like the most understandable to me, while keeping all the features you proposed @stas00.

The _call__mp name should probably be made more explicit, however, as right now it can't be understood from its name.

Contributor Author:

explicit calls it is then, works for me.

wrt _call_mp - yeah, it has already been renamed in the Bart goes MP PR - there are a lot of moving parts. And a lot of things will be renamed at the end - these are just all temp names anyway.

self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
self.last_device = "cpu"

@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# Check validity of device_map
self.device_map = (
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
)
assert_device_map(self.device_map, len(self.block))
self.model_parallel = True
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys()))
device_map = init_device_map(len(self.block), device_map)
self.first_device = f"cuda:{ list(device_map.keys())[0] }"
self.last_device = f"cuda:{ list(device_map.keys())[-1] }"
Contributor (@alexorona, Jan 2, 2021):

Maybe we should instead put some additional work into get_device_map so it just infers the map based on type(model):

self.device_map = get_device_map(model) if device_map is None else device_map
validate_device_map(self.device_map)

And keep the old language on first_device and last_device so we can support putting part of the model on "cpu".

self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys())) 

Contributor Author (@stas00, Jan 2, 2021):

I'm not sure we even need first/last devices.

But yes, I think eventually each model will have to implement its own default device map; see the work I started in #9384 (make_default_device_map in model_parallel_utils.py).

So there will be no need to infer, since the model will just have a method that will do that.

Of course, we don't have to re-implement it for each model when the archs are similar, so perhaps we will end up with a set of default maps, and each model will pick and customize the one it needs. E.g. t5 and bart are almost the same, with the exception of the name of the module list that holds the layers/blocks.

wrt some layers on cpu: you believe it'd be necessary to support that? If so, perhaps it should be handled transparently along with the gpu ids, with no special case required, so that you could use:

{ 'cpu': [1,2],
  0:     [3, 4, 5],
}

that's actually a cool thing since then one could play with MP with just one gpu.

Member:

No strong opposition to update from get_device_map to init_device_map, but it would be nice to keep CPU support here if possible.

Contributor Author:

It's not a rename, it's a change in functionality: the default and the check are merged into one function to avoid pointless code repetition, which doesn't contribute to easier understanding.

and yes, point taken on supporting cpu.

I propose that the first stage is to polish the API for gpus only and then definitely include cpu, as it'd be helpful for developing/debugging MP with just one gpu.

I will start making todo items.
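A merged default-plus-validation helper along the lines discussed could be sketched as follows. This is a hypothetical reconstruction: the device count is an explicit parameter here, whereas the real `init_device_map` would query `torch.cuda.device_count()`:

```python
def init_device_map(n_layers, device_map=None, n_devices=2):
    """Hypothetical sketch: build a default balanced device map when none is
    given, then validate that every layer is assigned exactly once."""
    if device_map is None:
        # Spread layers as evenly as possible; earlier devices get the extras.
        per_device, extra = divmod(n_layers, n_devices)
        device_map, start = {}, 0
        for d in range(n_devices):
            count = per_device + (1 if d < extra else 0)
            device_map[d] = list(range(start, start + count))
            start += count
    # Validation half of the merged helper: each layer exactly once.
    assigned = sorted(l for layers in device_map.values() for l in layers)
    if assigned != list(range(n_layers)):
        raise ValueError("device_map must cover each layer exactly once")
    return device_map
```

Supporting the `'cpu'` key proposed above would just mean allowing a string key in the same structure, with the `.to()` target derived from the key instead of always being `f"cuda:{k}"`.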

# Load onto devices
for k, v in self.device_map.items():
for k, v in device_map.items():
for layer in v:
cuda_device = "cuda:" + str(k)
self.block[layer] = self.block[layer].to(cuda_device)

# Set embed_tokens to first layer
self.embed_tokens = self.embed_tokens.to(self.first_device)
# Set final layer norm to last device
self.final_layer_norm = self.final_layer_norm.to(self.last_device)
self.block[layer].block_parallelize(f"cuda:{k}")
self.embed_tokens.to(self.first_device)
self.final_layer_norm.to(self.last_device)
self.model_parallel = True

@add_start_docstrings(PARALLELIZE_DOCSTRING)
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.model_parallel = False
self.device_map = None
self.to("cpu")
self.first_device = "cpu"
self.last_device = "cpu"
for i in range(len(self.block)):
self.block[i] = self.block[i].to("cpu")
self.embed_tokens = self.embed_tokens.to("cpu")
self.final_layer_norm = self.final_layer_norm.to("cpu")
self.block[i].block_deparallelize()
self.embed_tokens.to("cpu")
self.final_layer_norm.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()

def get_input_embeddings(self):
@@ -833,10 +841,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(self.first_device)
self.embed_tokens = self.embed_tokens.to(self.first_device)
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -904,20 +908,6 @@ def forward(
hidden_states = self.dropout(inputs_embeds)

for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if position_bias is not None:
position_bias = position_bias.to(hidden_states.device)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
if encoder_extended_attention_mask is not None:
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
if encoder_decoder_position_bias is not None:
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

@@ -952,12 +942,6 @@ def forward(
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))

hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)

@@ -1161,30 +1145,22 @@ def __init__(self, config: T5Config):

self.init_weights()

# Model parallel
self.model_parallel = False
self.device_map = None
self.main_device = None

@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.encoder.block))
self.encoder.parallelize(self.device_map)
self.decoder.parallelize(self.device_map)
device_map = init_device_map(len(self.encoder.block), device_map)
self.encoder.parallelize(device_map)
self.decoder.parallelize(device_map)
self.main_device = self.encoder.first_device
self.model_parallel = True

@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.decoder = self.decoder.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()

def get_input_embeddings(self):
@@ -1265,18 +1241,6 @@ def forward(
)

hidden_states = encoder_outputs[0]
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

# Decode
decoder_outputs = self.decoder(
@@ -1296,6 +1260,11 @@
if not return_dict:
return decoder_outputs + encoder_outputs

if self.model_parallel:
encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device(
self.main_device, encoder_outputs, decoder_outputs
)
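The `model_parallel_inputs_to_specific_device` helper used in this hunk could be sketched roughly as follows. This is a hypothetical reconstruction, using a duck-typed `.to()` check where the real helper in `model_parallel_utils` would likely use `torch.is_tensor`:

```python
def model_parallel_inputs_to_specific_device(device, *inputs):
    """Hypothetical sketch: recursively move tensor-like leaves (anything
    with a .to() method) found inside tuples, lists, or dicts onto device."""
    def move(x):
        if hasattr(x, "to"):
            return x.to(device)
        if isinstance(x, (tuple, list)):
            return type(x)(move(v) for v in x)
        if isinstance(x, dict):
            return {k: move(v) for k, v in x.items()}
        return x

    moved = tuple(move(i) for i in inputs)
    # Unwrap when a single input was passed, so callers can assign directly.
    return moved if len(moved) > 1 else moved[0]
```

This lets the model gather encoder and decoder outputs back onto `self.main_device` in one call before building the return value.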

return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
@@ -1341,32 +1310,26 @@ def __init__(self, config):

self.init_weights()

# Model parallel
self.model_parallel = False
self.device_map = None
self.first_device = None
self.last_device = None

@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.encoder.block))
self.encoder.parallelize(self.device_map)
self.decoder.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.decoder.first_device)
device_map = init_device_map(len(self.encoder.block), device_map)
self.encoder.parallelize(device_map)
self.decoder.parallelize(device_map)
self.first_device = self.encoder.first_device
self.last_device = self.decoder.last_device
self.lm_head.to(self.first_device)
self.model_parallel = True

@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.decoder = self.decoder.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.lm_head.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()

def get_input_embeddings(self):
@@ -1456,9 +1419,6 @@ def forward(

hidden_states = encoder_outputs[0]

if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)

if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
@@ -1472,17 +1432,6 @@
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
Expand All @@ -1498,13 +1447,13 @@ def forward(
return_dict=return_dict,
)

sequence_output = decoder_outputs[0]

# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device(
self.first_device, encoder_outputs, decoder_outputs
)
# self.lm_head.to(self.first_device)

sequence_output = decoder_outputs[0]

if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
@@ -1596,23 +1545,18 @@ def __init__(self, config: T5Config):

self.init_weights()

self.model_parallel = False

@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.encoder.block))
self.encoder.parallelize(self.device_map)
device_map = init_device_map(len(self.encoder.block), device_map)
self.encoder.parallelize(device_map)
self.model_parallel = True

@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.encoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()

def get_input_embeddings(self):
8 changes: 8 additions & 0 deletions src/transformers/trainer.py
@@ -267,6 +267,14 @@ def __init__(
)
self.model_init = model_init

if self.args.model_parallel:
if model.is_parallelizable:
model.parallelize()
else:
raise ValueError(
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore the --model_parallel command-line argument cannot be used"
)

default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset