diff --git a/setup.py b/setup.py
index d1cba0dfebf0..e145959bacc9 100644
--- a/setup.py
+++ b/setup.py
@@ -185,6 +185,7 @@
     "unidic>=1.0.2",
     "unidic_lite>=1.0.7",
     "uvicorn",
+    "xformers==0.0.16",
 ]
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 3f1806fb9a8c..e602677d7d20 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -46,6 +46,8 @@
 )
 from ...utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_gpt2 import GPT2Config
+import xformers
+import xformers.ops
 
 
 logger = logging.get_logger(__name__)
@@ -121,10 +123,11 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
 
 
 class GPT2Attention(nn.Module):
-    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None, use_xformers=False):
         super().__init__()
 
         max_positions = config.max_position_embeddings
+        self.use_xformers = use_xformers
+        self.attention_op = None  # let xformers pick the fastest available kernel
         self.register_buffer(
             "bias",
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
@@ -179,42 +182,53 @@ def prune_heads(self, heads):
         self.pruned_heads = self.pruned_heads.union(heads)
 
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+        if self.use_xformers:
+            if not self.is_cross_attention:
+                mask = xformers.ops.LowerTriangularMask()
+            else:
+                mask = attention_mask
+            # memory_efficient_attention expects (batch, seq_len, num_heads, head_dim),
+            # so transpose from GPT-2's (batch, num_heads, seq_len, head_dim) layout and back.
+            attn_output = xformers.ops.memory_efficient_attention(
+                query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), attn_bias=mask, op=self.attention_op
             )
+            attn_output = attn_output.transpose(1, 2).to(query.dtype)
+            # scale_attn_by_inverse_layer_idx, head_mask and attention dropout are not applied on this path
+            attn_weights = None
+        else:
+            attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
+            if self.scale_attn_weights:
+                attn_weights = attn_weights / torch.full(
+                    [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+                )
 
-        if not self.is_cross_attention:
-            # if only "normal" attention layer implements causal mask
-            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-            mask_value = torch.finfo(attn_weights.dtype).min
-            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+            # Layer-wise attention scaling
+            if self.scale_attn_by_inverse_layer_idx:
+                attn_weights = attn_weights / float(self.layer_idx + 1)
 
-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_weights = attn_weights + attention_mask
+            if not self.is_cross_attention:
+                # if only "normal" attention layer implements causal mask
+                query_length, key_length = query.size(-2), key.size(-2)
+                causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+                mask_value = torch.finfo(attn_weights.dtype).min
+                # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+                # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+                mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+                attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
 
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            if attention_mask is not None:
+                # Apply the attention mask
+                attn_weights = attn_weights + attention_mask
 
-        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
-        attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
-        # Mask heads if we want to
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+            # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+            attn_weights = attn_weights.type(value.dtype)
+            attn_weights = self.attn_dropout(attn_weights)
 
-        attn_output = torch.matmul(attn_weights, value)
+            # Mask heads if we want to
+            if head_mask is not None:
+                attn_weights = attn_weights * head_mask
+
+            attn_output = torch.matmul(attn_weights, value)
 
         return attn_output, attn_weights
 
@@ -1586,3 +1600,4 @@ def forward(
         hidden_states=transformer_outputs.hidden_states,
         attentions=transformer_outputs.attentions,
     )
+
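For reference, the xformers kernel this patch swaps in can be exercised on its own. The snippet below is a minimal sketch, not part of the patch: it assumes xformers==0.0.16 and a CUDA device, and the tensor shapes are illustrative only; only memory_efficient_attention and LowerTriangularMask come from the change above.

# Standalone sketch of the call that replaces GPT-2's matmul/softmax attention.
import torch
import xformers.ops

batch, seq_len, num_heads, head_dim = 2, 128, 12, 64

# memory_efficient_attention takes (batch, seq_len, num_heads, head_dim),
# i.e. the transpose of GPT-2's (batch, num_heads, seq_len, head_dim) layout.
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# LowerTriangularMask() reproduces the causal masking GPT-2 otherwise builds from
# its `bias` buffer, without materializing the full (seq_len, seq_len) matrix.
out = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=xformers.ops.LowerTriangularMask()
)
print(out.shape)  # torch.Size([2, 128, 12, 64])

With the patch applied, the same kernel is reached by constructing the attention module with the new flag, e.g. GPT2Attention(GPT2Config(), use_xformers=True); the default use_xformers=False keeps the original behavior.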