1 change: 1 addition & 0 deletions setup.py
@@ -185,6 +185,7 @@
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
"uvicorn",
"xformers==0.0.16"
]


73 changes: 44 additions & 29 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -46,6 +46,8 @@
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gpt2 import GPT2Config
import xformers
import xformers.ops


logger = logging.get_logger(__name__)
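
Together with the setup.py pin above, the new module-level imports make xformers a hard requirement just to import modeling_gpt2.py. A minimal sketch of a guarded import that would keep the dependency optional (a hypothetical alternative for illustration, not something this diff does):

# Hypothetical guard, not part of this diff: degrade gracefully when xformers
# is not installed so the extra dependency stays optional.
try:
    import xformers.ops

    _XFORMERS_AVAILABLE = True
except ImportError:
    xformers = None
    _XFORMERS_AVAILABLE = False

With a guard like this, GPT2Attention.__init__ could raise a clear error when use_xformers=True but the library is unavailable.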
@@ -121,10 +123,11 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):


class GPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
def __init__(self, config, is_cross_attention=False, layer_idx=None, use_xformers=False):
super().__init__()

max_positions = config.max_position_embeddings
self.use_xformers = use_xformers
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
@@ -179,42 +182,53 @@ def prune_heads(self, heads):
self.pruned_heads = self.pruned_heads.union(heads)

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))

if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
if self.use_xformers:
if not self.is_cross_attention:
mask = xformers.ops.LowerTriangularMask()
else:
mask = attention_mask
attn_output = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=mask, op=self.attention_op
)
attn_output = attn_output.to(query.dtype)
attn_weights = None
else:
attn_weights = torch.matmul(query, key.transpose(-1, -2))

# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)

if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)

if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask

# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)

attn_output = torch.matmul(attn_weights, value)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask

attn_output = torch.matmul(attn_weights, value)

return attn_output, attn_weights
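
For reference, a minimal standalone sketch of the call pattern the new use_xformers branch relies on. It assumes xformers 0.0.16 with a CUDA device, and uses the [batch, seq_len, num_heads, head_dim] layout that xformers.ops.memory_efficient_attention expects (GPT2Attention's _split_heads produces [batch, num_heads, seq_len, head_dim], so a transpose may be needed before the call); all sizes below are illustrative:

import torch
import xformers.ops

# Illustrative sizes only; none of these values come from the PR.
batch, seq_len, num_heads, head_dim = 2, 128, 12, 64

# xformers expects [batch, seq_len, num_heads, head_dim] tensors.
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# LowerTriangularMask() reproduces the causal masking of the self-attention path;
# the cross-attention path in the diff passes the incoming attention_mask instead.
attn_output = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=xformers.ops.LowerTriangularMask()
)
print(attn_output.shape)  # torch.Size([2, 128, 12, 64]), same layout as the query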

@@ -1586,3 +1600,4 @@ def forward(
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)