From 7a2572018fe6686b66876ce1ae75b99529ef86d1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 26 Feb 2024 17:04:08 +1100 Subject: [PATCH 01/17] Update modeling_llama.py Llama - Force float32 since bfloat16 loses precision on long contexts --- src/transformers/models/llama/modeling_llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 66a50c580891..0784c9266d2c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -123,7 +123,9 @@ def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + # Force float32 since bfloat16 loses precision on long contexts + with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(dtype=x.dtype) sin = emb.sin().to(dtype=x.dtype) From db8237f449eacbcc2df9a9a82a6ae77ebf76478a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 26 Feb 2024 17:06:37 +1100 Subject: [PATCH 02/17] Update modeling_llama.py --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0784c9266d2c..68a62df4b35f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -125,7 +125,7 @@ def forward(self, x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(dtype=x.dtype) sin = emb.sin().to(dtype=x.dtype) From 3de95c42f055a5939733e2df4f2cba4b5c758463 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 26 Feb 2024 17:08:10 +1100 Subject: [PATCH 03/17] Update modeling_gemma.py Fix RoPE and logits.float() --- src/transformers/models/gemma/modeling_gemma.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 4cb12ff47005..dc9922274976 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -110,7 +110,11 @@ def forward(self, x, position_ids, seq_len=None): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + + # Force float32 since bfloat16 loses precision on long contexts + with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) return emb.cos().to(dtype=x.dtype), 
emb.sin().to(dtype=x.dtype) @@ -1079,7 +1083,8 @@ def forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) - + logits = logits.float() + loss = None if labels is not None: # Shift so that tokens < n predict n From 99d564e7e6d07aab633f3bd86d316b740d31e68d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 27 Feb 2024 19:58:54 +1100 Subject: [PATCH 04/17] @torch.no_grad() --- src/transformers/models/llama/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 249fe878e4c7..7c1e308f3882 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -116,6 +116,7 @@ def cos_cached(self): ) return self._cos_cached + @torch.no_grad() def forward(self, x, position_ids, seq_len=None): if seq_len is not None: logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.") From d0c08bf67819d7d4ba6906274f0ad7770004ae4a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 27 Feb 2024 20:00:04 +1100 Subject: [PATCH 05/17] @torch.no_grad() --- src/transformers/models/gemma/modeling_gemma.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 386da1918c5d..ed4dabc70b4b 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -101,20 +101,19 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base self.register_buffer("inv_freq", None, persistent=False) + @torch.no_grad() def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if self.inv_freq is None: self.inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) ) - + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) From abffebb6822e756aef76348e96244fc21d31b579 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 Feb 2024 16:21:43 +1100 Subject: [PATCH 06/17] Cos, Sin to float32 --- src/transformers/models/gemma/modeling_gemma.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ed4dabc70b4b..a241c4e0f761 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -112,10 +112,13 @@ def forward(self, x, position_ids, seq_len=None): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = 
torch.cat((freqs, freqs), dim=-1) - return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half From c2e31bf466fbb73004427f3db13793fc1b3a259a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 Feb 2024 16:22:33 +1100 Subject: [PATCH 07/17] cos, sin to float32 --- src/transformers/models/llama/modeling_llama.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7c1e308f3882..a8b66b8bd988 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -125,11 +125,14 @@ def forward(self, x, position_ids, seq_len=None): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=x.dtype) - sin = emb.sin().to(dtype=x.dtype) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + cos = cos.to(dtype=x.dtype) + sin = sin.to(dtype=x.dtype) # backwards compatibility self._cos_cached = cos self._sin_cached = sin From f487800fffdef976f0b8c0923b8a592cc91a073a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 Feb 2024 21:30:19 +1100 Subject: [PATCH 08/17] Update src/transformers/models/gemma/modeling_gemma.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/gemma/modeling_gemma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a241c4e0f761..1e1661595fd3 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -113,7 +113,7 @@ def forward(self, x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): + with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() From c8526756d8915278f84f91c675651f1818df4d04 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 Feb 2024 21:30:32 +1100 Subject: [PATCH 09/17] Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a8b66b8bd988..145ac110b7da 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -126,7 +126,7 @@ def forward(self, 
x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False): + with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() From 1a50a4bceaacc64595d6efac2f28fa7b1f1c0434 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Wed, 28 Feb 2024 21:48:24 +1100 Subject: [PATCH 10/17] Resolve PR conflicts --- .../models/gemma/modeling_gemma.py | 2 - .../models/llama/modeling_llama.py | 48 ++++++++----------- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 1e1661595fd3..95faa1998138 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -108,7 +108,6 @@ def forward(self, x, position_ids, seq_len=None): self.inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) ) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts @@ -1089,7 +1088,6 @@ def forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - loss = None if labels is not None: # Shift so that tokens < n predict n diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 145ac110b7da..1f9ee6bb1a56 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -92,60 +92,55 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() + self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): logger.warning_once( - "The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead." + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. 
It is not used in the `LlamaAttention` class" ) return self._sin_cached @property def cos_cached(self): logger.warning_once( - "The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead." + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" ) return self._cos_cached - @torch.no_grad() def forward(self, x, position_ids, seq_len=None): if seq_len is not None: - logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.") + logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.") # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - cos = cos.to(dtype=x.dtype) - sin = sin.to(dtype=x.dtype) - # backwards compatibility - self._cos_cached = cos - self._sin_cached = sin - return cos, sin + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: a scaling factor is aplied to the position ids position_ids = position_ids.float() / self.scaling_factor @@ -156,10 +151,6 @@ def forward(self, x, position_ids, seq_len=None): class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - def forward(self, x, position_ids, seq_len=None): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length seq_len = torch.max(position_ids) + 1 @@ -373,6 +364,7 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask if cache_position is not None: causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask From b860a22dab9bb01cd15cb9a3220abeaefad3e458 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Wed, 28 Feb 2024 21:52:02 +1100 Subject: [PATCH 11/17] Fix RoPE for llama --- src/transformers/models/llama/modeling_llama.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 1f9ee6bb1a56..f166e19935f1 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -126,6 +126,7 @@ def cos_cached(self): ) return self._cos_cached + @torch.no_grad() def forward(self, x, position_ids, seq_len=None): if seq_len is not None: logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.") @@ -133,9 +134,19 @@ def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + cos = cos.to(dtype=x.dtype) + sin = sin.to(dtype=x.dtype) + # backwards compatibility + self._cos_cached = cos + self._sin_cached = sin + return cos, sin class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): From 790e4a3ae058d78ca114104ffd592cc656d1c79d Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Wed, 28 Feb 2024 21:52:44 +1100 Subject: [PATCH 12/17] Revert "Fix RoPE for llama" This reverts commit b860a22dab9bb01cd15cb9a3220abeaefad3e458. --- src/transformers/models/llama/modeling_llama.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f166e19935f1..1f9ee6bb1a56 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -126,7 +126,6 @@ def cos_cached(self): ) return self._cos_cached - @torch.no_grad() def forward(self, x, position_ids, seq_len=None): if seq_len is not None: logger.warning_once("The `seq_len` argument is deprecated and unused. 
It will be removed in v4.39.") @@ -134,19 +133,9 @@ def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - cos = cos.to(dtype=x.dtype) - sin = sin.to(dtype=x.dtype) - # backwards compatibility - self._cos_cached = cos - self._sin_cached = sin - return cos, sin + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): From aa03a433ed7bfaf93abf4d5701979b2006666f3c Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Wed, 28 Feb 2024 21:54:10 +1100 Subject: [PATCH 13/17] Fix RoPE for llama --- src/transformers/models/llama/modeling_llama.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 1f9ee6bb1a56..f166e19935f1 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -126,6 +126,7 @@ def cos_cached(self): ) return self._cos_cached + @torch.no_grad() def forward(self, x, position_ids, seq_len=None): if seq_len is not None: logger.warning_once("The `seq_len` argument is deprecated and unused. 
It will be removed in v4.39.") @@ -133,9 +134,19 @@ def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + cos = cos.to(dtype=x.dtype) + sin = sin.to(dtype=x.dtype) + # backwards compatibility + self._cos_cached = cos + self._sin_cached = sin + return cos, sin class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): From 5730a5037b496ff8267b073d3712c8a5dd9a8a61 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Wed, 28 Feb 2024 23:23:00 +1100 Subject: [PATCH 14/17] RoPE device --- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 95faa1998138..7a3bac63edfd 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -112,7 +112,7 @@ def forward(self, x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): + with torch.autocast(device_type=x.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f166e19935f1..3a5839a617f6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -136,7 +136,7 @@ def forward(self, x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=self.inv_freq.device.type, enabled=False): + with torch.autocast(device_type=x.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() From 31cea3b38f6dc4dd0a27e3b7a6d2dabb4fa2f942 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Thu, 29 Feb 2024 00:35:43 +1100 Subject: [PATCH 15/17] Autocast device type --- src/transformers/models/gemma/modeling_gemma.py | 4 +++- src/transformers/models/llama/modeling_llama.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 7a3bac63edfd..47966ac4518e 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ 
b/src/transformers/models/gemma/modeling_gemma.py @@ -112,7 +112,9 @@ def forward(self, x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=x.device.type, enabled=False): + device_type = x.device.type + device_type = device_type if type(device_type) is str else "cpu" + with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3a5839a617f6..208affbddac7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -136,7 +136,9 @@ def forward(self, x, position_ids, seq_len=None): position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 - with torch.autocast(device_type=x.device.type, enabled=False): + device_type = x.device.type + device_type = device_type if type(device_type) is str else "cpu" + with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() From ae9957f389fcf9f9586844697a35aa2e0f8b5d71 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Thu, 29 Feb 2024 00:51:42 +1100 Subject: [PATCH 16/17] RoPE --- src/transformers/models/llama/modeling_llama.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 208affbddac7..8a71dde8ff7e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -143,12 +143,7 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - cos = cos.to(dtype=x.dtype) - sin = sin.to(dtype=x.dtype) - # backwards compatibility - self._cos_cached = cos - self._sin_cached = sin - return cos, sin + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): From ec9ef17f6931ab9d5b2b0ed0aa69e1ae0c03e91f Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Thu, 29 Feb 2024 00:57:48 +1100 Subject: [PATCH 17/17] RoPE isinstance --- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 47966ac4518e..72e07ea82467 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -113,7 +113,7 @@ def forward(self, x, position_ids, seq_len=None): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if type(device_type) is str else "cpu" + device_type = device_type if isinstance(device_type, str) else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ 
position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8a71dde8ff7e..0179f370ca0c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -137,7 +137,7 @@ def forward(self, x, position_ids, seq_len=None): # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if type(device_type) is str else "cpu" + device_type = device_type if isinstance(device_type, str) else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1)
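
Note on the underlying numerics: the precision loss these patches guard against is easy to reproduce in isolation. The following is a minimal, self-contained sketch (illustrative only, not part of the patch series); the parameters follow the Llama defaults seen above (head dim 128, base 10000), and the explicit bfloat16 casts stand in for what an enclosing autocast region would do to the matmul operands.

import torch

dim, base, seq_len = 128, 10000, 8192

# Llama-style inverse frequencies; the highest frequency is exactly 1.0.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
position_ids = torch.arange(seq_len).float()

# Reference rotary angles, computed entirely in float32.
freqs_fp32 = torch.outer(position_ids, inv_freq)

# Same computation with both operands routed through bfloat16, which is
# roughly what autocast would do to the matmul without the fix above.
freqs_bf16 = torch.outer(position_ids.bfloat16(), inv_freq.bfloat16()).float()

max_err = (freqs_fp32 - freqs_bf16).abs().max().item()
print(f"max rotary angle error over {seq_len} positions: {max_err:.2f} rad")

bfloat16 keeps only 8 significant bits, so integer position ids above 256 round to the nearest representable value: positions between 4096 and 8192 are quantized in steps of 32, giving angle errors of up to roughly 16 radians at the highest frequency. Because the rounding error grows with the position id itself, short prompts look fine while long contexts degrade, which is why the final form of the patch both disables autocast around the float32 matmul and casts cos/sin back to x.dtype only after the float32 computation.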