Merged

35 commits
31aad90
add: 2 variants of multi query implementation; printing some details
bigximik Aug 31, 2022
9c13b66
Unpin PyTorch
sgugger Nov 1, 2022
1ebb3f7
Release v4.24.0
sgugger Nov 1, 2022
0e654e0
Added onnx config whisper (#19525)
mht-sharma Nov 1, 2022
502d3b6
Remove pin temporarily to get tests
sgugger Nov 1, 2022
8f95346
Add ESMFold code sample (#20000)
Rocketknight1 Nov 1, 2022
94b3f54
Unpin PyTorch for the release
sgugger Nov 1, 2022
a367bc0
Merge tag 'v4.24.0' into mayank/multi_query
mayank31398 Dec 4, 2022
b7e2124
fix saving
mayank31398 Dec 5, 2022
5171b4f
Merge branch 'main' of github.com:bigcode-collaboration/transformers …
bigximik Jan 20, 2023
357ba81
added Raymond MQA variant
bigximik Jan 20, 2023
a96771f
chg: tensor vs fill acc to comments by Joel
bigximik Jan 23, 2023
14f2249
Style and fix
jlamypoirier Jan 23, 2023
4dc821c
cleanup
jlamypoirier Jan 23, 2023
129e8c9
cleanup
jlamypoirier Jan 23, 2023
303e1b8
cleanup
jlamypoirier Jan 23, 2023
d0b58e9
cleanup
jlamypoirier Jan 23, 2023
a1e9182
Fixes and cleanup
jlamypoirier Jan 23, 2023
a57ca7a
Fixes and cleanup
jlamypoirier Jan 23, 2023
e152e94
Fixes and merge implementations
jlamypoirier Jan 23, 2023
2e32a95
Fixes and improvements
jlamypoirier Jan 24, 2023
93b42d2
simplify and fix
jlamypoirier Jan 24, 2023
82b11df
Fixes, optimization and comments
jlamypoirier Jan 25, 2023
98319da
Best GeLU]
jlamypoirier Jan 26, 2023
d81e46f
simpler gelu
jlamypoirier Jan 27, 2023
33ae645
Merge branch 'main' into joel-mqa
jlamypoirier Feb 7, 2023
2f299a2
Merge branch 'gpt2_bigcode' into joel-mqa
jlamypoirier Feb 7, 2023
a1d7a95
Move code
jlamypoirier Feb 7, 2023
138fefb
fix
jlamypoirier Feb 7, 2023
732e447
Merge branch 'gpt2_bigcode' into joel-mqa
jlamypoirier Feb 7, 2023
52a6d97
fix
jlamypoirier Feb 7, 2023
b4c9cf4
fix
jlamypoirier Feb 7, 2023
3dd5a5b
gelu
jlamypoirier Feb 7, 2023
2c13d08
Add changes from fast inferences
jlamypoirier Feb 7, 2023
e8a024c
Merge branch 'gpt2_bigcode' into joel-mqa
jlamypoirier Feb 7, 2023
23 changes: 21 additions & 2 deletions src/transformers/activations.py
@@ -25,6 +25,26 @@
logger = logging.get_logger(__name__)


class PytorchGELUTanh(nn.Module):
    """
    A fast C implementation of the tanh approximation of the GeLU activation function. See
    https://arxiv.org/abs/1606.08415.
    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
    match due to rounding errors.
    """

    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
                "PytorchGELUTanh. Please upgrade torch."
            )

    def forward(self, input: Tensor) -> Tensor:
        return nn.functional.gelu(input, approximate="tanh")


class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
@@ -80,10 +100,8 @@ class ClippedGELUActivation(nn.Module):
    Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
    it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
    https://arxiv.org/abs/2004.09602.

    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
    """
@@ -155,6 +173,7 @@ def __getitem__(self, key):
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
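The new activation is registered under the "gelu_pytorch_tanh" key above. A minimal standalone sketch (assuming torch >= 1.12; the tolerance is illustrative) showing that the fused kernel matches the explicit tanh approximation used by NewGELUActivation up to rounding error:

import math

import torch

x = torch.randn(4, 8)

# Fused C++ tanh approximation (torch >= 1.12).
fast = torch.nn.functional.gelu(x, approximate="tanh")

# Explicit tanh approximation, as in NewGELUActivation / FastGELUActivation.
slow = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

print(torch.allclose(fast, slow, atol=1e-6))  # close, but not bit-exact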
13 changes: 12 additions & 1 deletion src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -15,6 +15,7 @@
# limitations under the License.
""" OpenAI GPT-2 configuration"""
from collections import OrderedDict
from enum import Enum
from typing import Any, List, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available
@@ -31,6 +32,12 @@
}


class AttentionType(Enum):
    MULTI_HEAD = 1
    MULTI_QUERY_1 = 2
    MULTI_QUERY_2 = 3


class GPTBigCodeConfig(PretrainedConfig):
    """
    # TODO: Update doc
@@ -143,7 +150,7 @@ def __init__(
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        activation_function="gelu_pytorch_tanh",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
@@ -160,6 +167,7 @@ def __init__(
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        attention_type=AttentionType.MULTI_HEAD,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -187,6 +195,9 @@ def __init__(
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        # Convert to an int so it's JSON-serializable.
        self.attention_type = AttentionType(attention_type).value

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

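A usage sketch, assuming this branch is installed and the import path matches the file above (the values are illustrative); the constructor stores the enum's integer value so the config still serializes to plain JSON:

from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig

config = GPTBigCodeConfig(n_head=12, attention_type=AttentionType.MULTI_QUERY_1)

# Stored as the int value (here 2), so to_json_string() works out of the box.
assert config.attention_type == AttentionType.MULTI_QUERY_1.value
print(config.to_json_string())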
189 changes: 97 additions & 92 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -44,7 +44,7 @@
    replace_return_docstrings,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gpt_bigcode import GPTBigCodeConfig
from .configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig


logger = logging.get_logger(__name__)
@@ -121,16 +121,21 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
            "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
        # We don't use a buffer because the mask value depends on the dtype,
        # And the dtype will be different if upcasting.
        self.mask_value = None

        self.attention_type = AttentionType(config.attention_type)
        self.is_mqa = self.attention_type != AttentionType.MULTI_HEAD

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.kv_heads = 1 if self.is_mqa else self.num_heads
        self.kv_dim = self.kv_heads * self.head_dim
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
@@ -146,11 +151,27 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        self.scale_factor = 1.0
        if self.scale_attn_weights:
            self.scale_factor /= self.head_dim**0.5

        if self.scale_attn_by_inverse_layer_idx:
            self.scale_factor /= self.layer_idx + 1

        if self.is_cross_attention:
            if self.is_mqa:
                raise NotImplementedError(f"attention_type {self.attention_type} for cross_attention")

            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
            if self.attention_type == AttentionType.MULTI_QUERY_2:
                self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
                # Keys and values are shared across heads
                self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim)
            else:
                self.c_attn = Conv1D(self.embed_dim + 2 * self.kv_dim, self.embed_dim)

        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
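A shape-only sketch of the projection layouts set up above (the sizes are hypothetical; Conv1D takes output features first, then input features):

# Hypothetical sizes: embed_dim = 768, num_heads = 12, head_dim = 64.
embed_dim, num_heads, head_dim = 768, 12, 64

# MULTI_HEAD: fused QKV projection, every head gets its own K and V.
mha_c_attn_out = 3 * embed_dim                # 2304

# MULTI_QUERY_1: fused projection, a single K and V shared across heads.
mqa1_c_attn_out = embed_dim + 2 * head_dim    # 896

# MULTI_QUERY_2: separate projections for Q and for the shared K/V.
mqa2_q_attn_out = embed_dim                   # 768
mqa2_kv_attn_out = 2 * head_dim               # 128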
@@ -173,27 +194,52 @@ def prune_heads(self, heads):
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
    def _matmul(self, x, y, dtype=None, scale_factor=1.0):
        output_shape = (*x.size()[:-1], y.size(-1))
        if self.is_mqa:
            # Q x K: (b, sq, nh, hs) x (b, hs, sk) -> (b, sq, nh, sk)
            # A X V: (b, sq, nh, sk) x (b, sk, hs) -> (b, sq, nh, hs)
            output_view = (x.size(0), x.size(1) * x.size(2), y.size(-1))
            # No copy needed for MQA 2, or when layer_past is provided.
            x = x.reshape(*output_view[:-1], x.size(-1))
        else:
            # Q x K: (b, nh, sq, hs) x (b, nh, hs, sk) -> (b, nh, sq, sk)
            # A X V: (b, nh, sq, sk) x (b, nh, sk, hs) -> (b, nh, sq, hs)
            output_view = (x.size(0) * x.size(1), x.size(2), y.size(-1))
            # Always copies
            x = x.reshape(output_view[0], *x.size()[2:])
            # No copy when layer_past is provided.
            y = y.reshape(output_view[0], *y.size()[2:])
        # This is identical to matmul when scale_factor==1
        z = torch.empty(output_view, dtype=x.dtype if dtype is None else dtype, device=x.device)
        z = torch.baddbmm(z, x, y, beta=0, alpha=scale_factor)
        return z.view(output_shape)
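A minimal standalone check (batch and head sizes are illustrative) that the baddbmm call above, with beta=0 and a preallocated output, is just a scaled batched matmul whose dtype is fixed by the output buffer:

import torch

b, sq, sk, hs = 2, 5, 7, 64
scale = 1.0 / hs**0.5

x = torch.randn(b, sq, hs)
y = torch.randn(b, hs, sk)

# beta=0 means the (uninitialized) buffer is never read; it only fixes shape and dtype.
z = torch.empty(b, sq, sk)
out = torch.baddbmm(z, x, y, beta=0, alpha=scale)

print(torch.allclose(out, scale * torch.bmm(x, y)))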

    def _attn(self, query, key, value, attention_mask=None, head_mask=None, upcast=False):
        with autocast(enabled=False):
            attn_weights = self._matmul(
                query, key.transpose(-1, -2), dtype=torch.float32 if upcast else None, scale_factor=self.scale_factor
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
            key_length = key.size(-2)
            if self.is_mqa:
                # (b, sq, nh, sk)
                causal_mask = self.bias[None, key_length - query.size(1) : key_length, None, :key_length]
            else:
                # (b, nh, sq, sk)
                causal_mask = self.bias[None, None, key_length - query.size(-2) : key_length, :key_length]
            # torch.where expects a tensor. We use a cache to avoid recreating it every time.
            if (
                self.mask_value is None
                or self.mask_value.dtype != attn_weights.dtype
                or self.mask_value.device != attn_weights.device
            ):
                self.mask_value = torch.full(
                    [], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype, device=attn_weights.device
                )
            attn_weights = torch.where(causal_mask, attn_weights, self.mask_value)

        if attention_mask is not None:
            # Apply the attention mask
@@ -202,57 +248,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
        if upcast and attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)
@@ -261,39 +257,42 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output = self._matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
    def _split_heads(self, tensor, num_heads, attn_head_size, permute=True):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
        if permute:
            tensor = tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
    def _merge_heads(self, tensor, num_heads, attn_head_size, permute=True):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        if permute:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)
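A shape-only sketch (hypothetical sizes) of what the permute flag changes: the multi-head path keeps the usual (batch, heads, seq, head_dim) layout, while the multi-query path leaves heads next to head_dim so a single shared key/value of shape (batch, seq, head_dim) can broadcast against all heads:

import torch

batch, seq, num_heads, head_dim = 2, 10, 12, 64
hidden = torch.randn(batch, seq, num_heads * head_dim)

# permute=True (multi-head): (batch, num_heads, seq, head_dim)
mha_q = hidden.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)

# permute=False (multi-query): (batch, seq, num_heads, head_dim)
mqa_q = hidden.view(batch, seq, num_heads, head_dim)

print(mha_q.shape, mqa_q.shape)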

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
            if not hasattr(self, "q_attn") or not self.is_cross_attention:
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
@@ -303,11 +302,16 @@ def forward(
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
            if self.attention_type == AttentionType.MULTI_QUERY_2:
                query = self.q_attn(hidden_states)
                key, value = self.kv_attn(hidden_states).split((self.kv_dim, self.kv_dim), dim=2)
            else:
                query, key, value = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        query = self._split_heads(query, self.num_heads, self.head_dim, permute=not self.is_mqa)
        if not self.is_mqa:
            key = self._split_heads(key, self.num_heads, self.head_dim)
            value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
@@ -319,12 +323,11 @@ def forward(
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        attn_output, attn_weights = self._attn(
            query, key, value, attention_mask, head_mask, upcast=self.reorder_and_upcast_attn
        )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim, permute=not self.is_mqa)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
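For context on the layer_past tuple handled above, a minimal sketch (shapes hypothetical, cache handling not shown in this hunk) of how a shared multi-query key is typically extended by one generated token:

import torch

# MQA keeps one shared key/value head: (batch, past_seq, head_dim).
past_key = torch.randn(1, 7, 64)
new_key = torch.randn(1, 1, 64)

key = torch.cat((past_key, new_key), dim=-2)
print(key.shape)  # torch.Size([1, 8, 64])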

@@ -363,6 +366,8 @@ def __init__(self, config, layer_idx=None):
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            if AttentionType(config.attention_type) != AttentionType.MULTI_HEAD:
                raise NotImplementedError("Cross-attention not implemented for MQA")
            self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

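As a rough illustration of why multi-query attention matters for inference (the numbers are hypothetical and ignore dtype and layout details), the per-layer key/value cache shrinks by a factor of num_heads:

# Hypothetical sizes: batch=1, seq=2048, num_heads=12, head_dim=64.
batch, seq, num_heads, head_dim = 1, 2048, 12, 64

mha_kv_elements = 2 * batch * seq * num_heads * head_dim  # keys + values, one pair per head
mqa_kv_elements = 2 * batch * seq * head_dim              # keys + values, shared across heads

print(mha_kv_elements // mqa_kv_elements)  # 12 == num_heads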