diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 4e6e7cd8ab6d..72e07ea82467 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -101,18 +101,25 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         self.base = base
         self.register_buffer("inv_freq", None, persistent=False)
 
+    @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if self.inv_freq is None:
             self.inv_freq = 1.0 / (
                 self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
             )
-
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -1082,7 +1089,7 @@ def forward(
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-
+        logits = logits.float()
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 1f9ee6bb1a56..0179f370ca0c 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -126,6 +126,7 @@ def cos_cached(self):
         )
         return self._cos_cached
 
+    @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         if seq_len is not None:
             logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
@@ -133,9 +134,16 @@ def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
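
For context, here is a minimal standalone sketch (illustrative only, not part of the patch) of the precision loss the new comment refers to. bfloat16 carries only 8 significant bits, so at long contexts neighboring position ids round to the same representable value; since the rotary angles are `position * inv_freq`, distinct positions then get identical cos/sin rotations. Upcasting to float32 before the matmul, with autocast disabled so the operands are not re-cast down, keeps the angles distinct.

```python
import torch

# Illustrative only: ten consecutive position ids at a ~8k context length.
position_ids = torch.arange(8000, 8010)

# bfloat16 has an 8-bit significand, so the spacing between representable
# values in [4096, 8192) is 32 -- all ten ids round to the same number.
print(position_ids.to(torch.bfloat16).float().unique())  # tensor([8000.])

# float32 represents every integer up to 2**24 exactly, so the ids stay distinct.
print(position_ids.to(torch.float32).unique().numel())   # 10
```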