diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index d9caf8763e..1c59568835 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -25,6 +25,26 @@ logger = logging.get_logger(__name__)
 
 
+class PytorchGELUTanh(nn.Module):
+    """
+    A fast C implementation of the tanh approximation of the GeLU activation function. See
+    https://arxiv.org/abs/1606.08415.
+    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
+    match due to rounding errors.
+    """
+
+    def __init__(self):
+        super().__init__()
+        if version.parse(torch.__version__) < version.parse("1.12.0"):
+            raise ImportError(
+                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
+                "PytorchGELUTanh. Please upgrade torch."
+            )
+
+    def forward(self, input: Tensor) -> Tensor:
+        return nn.functional.gelu(input, approximate="tanh")
+
+
 class NewGELUActivation(nn.Module):
     """
     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
@@ -80,10 +100,8 @@ class ClippedGELUActivation(nn.Module):
     Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
     it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
     https://arxiv.org/abs/2004.09602.
-
     Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
     initially created.
-
     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
     torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
     """
@@ -155,6 +173,7 @@ def __getitem__(self, key):
     "gelu_fast": FastGELUActivation,
     "gelu_new": NewGELUActivation,
     "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+    "gelu_pytorch_tanh": PytorchGELUTanh,
     "linear": LinearActivation,
     "mish": MishActivation,
     "quick_gelu": QuickGELUActivation,
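Not part of the patch, just a sanity-check sketch of what the new `"gelu_pytorch_tanh"` entry maps to. Assuming torch >= 1.12 (which the constructor enforces), `nn.functional.gelu(..., approximate="tanh")` should agree with the hand-written tanh approximation used by `NewGELUActivation` up to rounding:

```python
import math

import torch

x = torch.randn(4, 8)

# The tanh approximation spelled out, as in NewGELUActivation / FastGELUActivation.
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

# The fused kernel that PytorchGELUTanh dispatches to (torch >= 1.12 only).
fused = torch.nn.functional.gelu(x, approximate="tanh")

# Close, but not bit-identical -- hence the "not an exact numerical match" caveat in the docstring.
print(torch.allclose(manual, fused, atol=1e-6), (manual - fused).abs().max().item())
```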
diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
index 8fcf554ded..6546b47253 100644
--- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 """ OpenAI GPT-2 configuration"""
 from collections import OrderedDict
+from enum import Enum
 from typing import Any, List, Mapping, Optional
 
 from transformers import PreTrainedTokenizer, TensorType, is_torch_available
@@ -31,6 +32,12 @@
 }
 
 
+class AttentionType(Enum):
+    MULTI_HEAD = 1
+    MULTI_QUERY_1 = 2
+    MULTI_QUERY_2 = 3
+
+
 class GPTBigCodeConfig(PretrainedConfig):
     """
     # TODO: Update doc
@@ -143,7 +150,7 @@ def __init__(
         n_layer=12,
         n_head=12,
         n_inner=None,
-        activation_function="gelu_new",
+        activation_function="gelu_pytorch_tanh",
         resid_pdrop=0.1,
         embd_pdrop=0.1,
         attn_pdrop=0.1,
@@ -160,6 +167,7 @@ def __init__(
         eos_token_id=50256,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
+        attention_type=AttentionType.MULTI_HEAD,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -187,6 +195,9 @@ def __init__(
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
+        # Convert to an int so it's JSON-serializable.
+        self.attention_type = AttentionType(attention_type).value
+
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
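A quick sketch of how the `attention_type` option is meant to round-trip (this assumes the branch from this patch is installed; the `AttentionType` import does not exist in mainline `transformers`). The constructor accepts an `AttentionType` member or its raw integer, but only the integer is stored, which keeps the config JSON-serializable:

```python
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig

config = GPTBigCodeConfig(attention_type=AttentionType.MULTI_QUERY_1)

# Stored as a plain int, so to_json_string() / save_pretrained() keep working.
assert config.attention_type == 2

# The modeling code recovers the enum with AttentionType(config.attention_type).
assert AttentionType(config.attention_type) is AttentionType.MULTI_QUERY_1
```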
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 942cea744b..82938e65ab 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -44,7 +44,7 @@
     replace_return_docstrings,
 )
 from ...utils.model_parallel_utils import assert_device_map, get_device_map
-from .configuration_gpt_bigcode import GPTBigCodeConfig
+from .configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig
 
 
 logger = logging.get_logger(__name__)
@@ -121,16 +121,21 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         max_positions = config.max_position_embeddings
         self.register_buffer(
-            "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
-                1, 1, max_positions, max_positions
-            ),
+            "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
         )
-        self.register_buffer("masked_bias", torch.tensor(-1e4))
+        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
+        # We don't use a buffer because the mask value depends on the dtype,
+        # and the dtype will be different if upcasting.
+        self.mask_value = None
+
+        self.attention_type = AttentionType(config.attention_type)
+        self.is_mqa = self.attention_type != AttentionType.MULTI_HEAD
 
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.kv_heads = 1 if self.is_mqa else self.num_heads
+        self.kv_dim = self.kv_heads * self.head_dim
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
@@ -146,11 +151,27 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         self.layer_idx = layer_idx
         self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
 
+        self.scale_factor = 1.0
+        if self.scale_attn_weights:
+            self.scale_factor /= self.head_dim**0.5
+
+        if self.scale_attn_by_inverse_layer_idx:
+            self.scale_factor /= self.layer_idx + 1
+
         if self.is_cross_attention:
+            if self.is_mqa:
+                raise NotImplementedError(f"attention_type {self.attention_type} for cross_attention")
+
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
-            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+            if self.attention_type == AttentionType.MULTI_QUERY_2:
+                self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+                # Keys and values are shared across heads
+                self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim)
+            else:
+                self.c_attn = Conv1D(self.embed_dim + 2 * self.kv_dim, self.embed_dim)
+
         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
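For review purposes, a back-of-the-envelope look at what `kv_heads`/`kv_dim` do to the fused QKV projection, using made-up sizes (not taken from any real checkpoint): with multi-query attention every query head shares one key/value head, so the projection width drops from `3 * embed_dim` to `embed_dim + 2 * head_dim`.

```python
# Hypothetical sizes, for illustration only.
embed_dim, num_heads = 2048, 16
head_dim = embed_dim // num_heads  # 128

# MULTI_HEAD: kv_heads == num_heads, so kv_dim == embed_dim and the fused
# projection is the familiar Conv1D(3 * embed_dim, embed_dim).
mha_c_attn_out = embed_dim + 2 * (num_heads * head_dim)  # 6144

# MULTI_QUERY_1 / MULTI_QUERY_2: kv_heads == 1, so kv_dim == head_dim.
mqa_c_attn_out = embed_dim + 2 * (1 * head_dim)  # 2304

print(mha_c_attn_out, mqa_c_attn_out)
```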
@@ -173,27 +194,52 @@ def prune_heads(self, heads):
         self.num_heads = self.num_heads - len(heads)
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+    def _matmul(self, x, y, dtype=None, scale_factor=1.0):
+        output_shape = (*x.size()[:-1], y.size(-1))
+        if self.is_mqa:
+            # Q x K: (b, sq, nh, hs) x (b, hs, sk) -> (b, sq, nh, sk)
+            # A x V: (b, sq, nh, sk) x (b, sk, hs) -> (b, sq, nh, hs)
+            output_view = (x.size(0), x.size(1) * x.size(2), y.size(-1))
+            # No copy needed for MQA 2, or when layer_past is provided.
+            x = x.reshape(*output_view[:-1], x.size(-1))
+        else:
+            # Q x K: (b, nh, sq, hs) x (b, nh, hs, sk) -> (b, nh, sq, sk)
+            # A x V: (b, nh, sq, sk) x (b, nh, sk, hs) -> (b, nh, sq, hs)
+            output_view = (x.size(0) * x.size(1), x.size(2), y.size(-1))
+            # Always copies
+            x = x.reshape(output_view[0], *x.size()[2:])
+            # No copy when layer_past is provided.
+            y = y.reshape(output_view[0], *y.size()[2:])
+        # This is identical to matmul when scale_factor==1
+        z = torch.empty(output_view, dtype=x.dtype if dtype is None else dtype, device=x.device)
+        z = torch.baddbmm(z, x, y, beta=0, alpha=scale_factor)
+        return z.view(output_shape)
+
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None, upcast=False):
+        with autocast(enabled=False):
+            attn_weights = self._matmul(
+                query, key.transpose(-1, -2), dtype=torch.float32 if upcast else None, scale_factor=self.scale_factor
             )
 
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
-            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-            mask_value = torch.finfo(attn_weights.dtype).min
-            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+            key_length = key.size(-2)
+            if self.is_mqa:
+                # (b, sq, nh, sk)
+                causal_mask = self.bias[None, key_length - query.size(1) : key_length, None, :key_length]
+            else:
+                # (b, nh, sq, sk)
+                causal_mask = self.bias[None, None, key_length - query.size(-2) : key_length, :key_length]
+            # torch.where expects a tensor. We use a cache to avoid recreating it every time.
+            if (
+                self.mask_value is None
+                or self.mask_value.dtype != attn_weights.dtype
+                or self.mask_value.device != attn_weights.device
+            ):
+                self.mask_value = torch.full(
+                    [], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype, device=attn_weights.device
+                )
+            attn_weights = torch.where(causal_mask, attn_weights, self.mask_value)
 
         if attention_mask is not None:
             # Apply the attention mask
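A self-contained sketch (plain 3-D tensors, arbitrary sizes) of the `baddbmm` trick that `_matmul` relies on: the attention scaling is folded into `alpha`, `beta=0` means the preallocated tensor only contributes its shape, dtype and device, and that dtype is how the `upcast=True` path gets float32 attention weights.

```python
import torch

b, sq, sk, hs = 2, 5, 7, 64
q = torch.randn(b, sq, hs)
k_t = torch.randn(b, hs, sk)
scale = hs ** -0.5

# out = beta * input + alpha * (q @ k_t); with beta=0 the "input" values are ignored.
out = torch.empty(b, sq, sk, dtype=torch.float32, device=q.device)
out = torch.baddbmm(out, q, k_t, beta=0, alpha=scale)

print(torch.allclose(out, torch.bmm(q, k_t) * scale, atol=1e-6))
```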
@@ -202,57 +248,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
-        attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
-    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
-        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
-        bsz, num_heads, q_seq_len, dk = query.size()
-        _, _, k_seq_len, _ = key.size()
-
-        # Preallocate attn_weights for `baddbmm`
-        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
-
-        # Compute Scale Factor
-        scale_factor = 1.0
-        if self.scale_attn_weights:
-            scale_factor /= float(value.size(-1)) ** 0.5
-
-        if self.scale_attn_by_inverse_layer_idx:
-            scale_factor /= float(self.layer_idx + 1)
-
-        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
-            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-
-        if not self.is_cross_attention:
-            # if only "normal" attention layer implements causal mask
-            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            mask_value = torch.finfo(attn_weights.dtype).min
-            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_weights = attn_weights + attention_mask
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
-        if attn_weights.dtype != torch.float32:
+        if upcast and attn_weights.dtype != torch.float32:
             raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
         attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
@@ -261,39 +257,42 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        attn_output = torch.matmul(attn_weights, value)
+        attn_output = self._matmul(attn_weights, value)
 
         return attn_output, attn_weights
 
-    def _split_heads(self, tensor, num_heads, attn_head_size):
+    def _split_heads(self, tensor, num_heads, attn_head_size, permute=True):
         """
         Splits hidden_size dim into attn_head_size and num_heads
         """
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
-        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+        if permute:
+            tensor = tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+        return tensor
 
-    def _merge_heads(self, tensor, num_heads, attn_head_size):
+    def _merge_heads(self, tensor, num_heads, attn_head_size, permute=True):
        """
         Merges attn_head_size dim and num_attn_heads dim into hidden_size
         """
-        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        if permute:
+            tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
         return tensor.view(new_shape)
 
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], ...]:
         if encoder_hidden_states is not None:
-            if not hasattr(self, "q_attn"):
+            if not hasattr(self, "q_attn") or not self.is_cross_attention:
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
                     "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
@@ -303,11 +302,16 @@ def forward(
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
         else:
-            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+            if self.attention_type == AttentionType.MULTI_QUERY_2:
+                query = self.q_attn(hidden_states)
+                key, value = self.kv_attn(hidden_states).split((self.kv_dim, self.kv_dim), dim=2)
+            else:
+                query, key, value = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2)
 
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
+        query = self._split_heads(query, self.num_heads, self.head_dim, permute=not self.is_mqa)
+        if not self.is_mqa:
+            key = self._split_heads(key, self.num_heads, self.head_dim)
+            value = self._split_heads(value, self.num_heads, self.head_dim)
 
         if layer_past is not None:
             past_key, past_value = layer_past
@@ -319,12 +323,11 @@ def forward(
         else:
             present = None
 
-        if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
-        else:
-            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        attn_output, attn_weights = self._attn(
+            query, key, value, attention_mask, head_mask, upcast=self.reorder_and_upcast_attn
+        )
 
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim, permute=not self.is_mqa)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
@@ -363,6 +366,8 @@ def __init__(self, config, layer_idx=None):
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
         if config.add_cross_attention:
+            if AttentionType(config.attention_type) != AttentionType.MULTI_HEAD:
+                raise NotImplementedError("Cross-attention not implemented for MQA")
             self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
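One consequence of the forward-pass changes worth spelling out (illustrative numbers, not from the patch): multi-query keys/values are never split into heads, so the cached `layer_past` tensors lose their head dimension and the per-layer KV cache shrinks by roughly a factor of `num_heads`.

```python
import torch

# Hypothetical sizes, for illustration only.
batch, num_heads, head_dim, past_length = 1, 16, 128, 1024

# MULTI_HEAD: cached keys/values keep a head dimension -> (batch, num_heads, past_length, head_dim).
mha_past_key = torch.empty(batch, num_heads, past_length, head_dim)

# MULTI_QUERY_*: keys/values skip _split_heads -> (batch, past_length, head_dim).
mqa_past_key = torch.empty(batch, past_length, head_dim)

print(mha_past_key.numel() // mqa_past_key.numel())  # 16
```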