diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 0ce2be3c62ac..b9260398cee4 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -785,17 +785,15 @@ def parallelize(self, device_map=None):
         )
         assert_device_map(self.device_map, len(self.block))
         self.model_parallel = True
-        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
+        self.first_device = "cpu" if "cpu" in self.device_map.keys() else f"cuda:{ list(self.device_map.keys())[0] }"
+        self.last_device = f"cuda:{ list(self.device_map.keys())[-1] }"
         # Load onto devices
         for k, v in self.device_map.items():
             for layer in v:
                 cuda_device = "cuda:" + str(k)
                 self.block[layer] = self.block[layer].to(cuda_device)
-        # Set embed_tokens to first layer
         self.embed_tokens = self.embed_tokens.to(self.first_device)
-        # Set final layer norm to last device
         self.final_layer_norm = self.final_layer_norm.to(self.last_device)
 
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def deparallelize(self):
@@ -833,7 +831,6 @@ def forward(
         # Model parallel
         if self.model_parallel:
             torch.cuda.set_device(self.first_device)
-            self.embed_tokens = self.embed_tokens.to(self.first_device)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -904,9 +901,7 @@ def forward(
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
+                # Ensure that the layer_module args are on the same device as hidden_states
                 if position_bias is not None:
                     position_bias = position_bias.to(hidden_states.device)
                 if encoder_hidden_states is not None:
@@ -915,6 +910,9 @@ def forward(
                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                 if encoder_decoder_position_bias is not None:
                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                if all_hidden_states is not None:
+                    all_hidden_states = all_hidden_states.to(hidden_states.device)
+
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -951,9 +949,10 @@ def forward(
 
             # Model Parallel: If it's the last layer for that device, put things on the next device
             if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+                devices = list(self.device_map.keys())
+                for e, d in enumerate(devices):
+                    if i == self.device_map[d][-1] and f"cuda:{d}" != self.last_device:
+                        hidden_states = hidden_states.to(f"cuda:{devices[e+1]}")
 
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
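
For context, the behavioral core of the last hunk is the switch from the "cuda:" + str(k + 1) arithmetic to indexing into the ordered list of device_map keys, so that a non-contiguous map (for example, one that skips a GPU) still hops to the next mapped device rather than an unmapped one. Below is a minimal standalone sketch of that logic in plain Python, with no torch required; the device_map shape {gpu_index: [block indices]} mirrors what parallelize() consumes, and the helper name next_device_for_layer is invented here purely for illustration, not part of the diff.

    # Sketch of the device-hopping logic from the last hunk, isolated from the model.
    # Assumption: device_map maps a GPU index to the list of block indices it hosts,
    # in the same insertion order that parallelize() relies on.

    def next_device_for_layer(i, device_map, last_device):
        """Return the device hidden states should move to after block `i`, or None."""
        devices = list(device_map.keys())
        for e, d in enumerate(devices):
            # If block `i` is the last block hosted on device `d` and `d` is not the
            # final device, hop to the next key in the map, not to `d + 1`.
            if i == device_map[d][-1] and f"cuda:{d}" != last_device:
                return f"cuda:{devices[e + 1]}"
        return None

    # A non-contiguous map (GPU 1 is skipped). The old `k + 1` arithmetic would have
    # produced "cuda:1" after block 3; the index-based lookup yields "cuda:2".
    device_map = {0: [0, 1, 2, 3], 2: [4, 5, 6, 7], 3: [8, 9, 10, 11]}
    last_device = f"cuda:{list(device_map.keys())[-1]}"  # "cuda:3", mirroring parallelize()

    assert next_device_for_layer(3, device_map, last_device) == "cuda:2"
    assert next_device_for_layer(7, device_map, last_device) == "cuda:3"
    assert next_device_for_layer(11, device_map, last_device) is None  # already on the last device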