From 96881bffe2d5d0e61a7ba892cdb418c0f5c87384 Mon Sep 17 00:00:00 2001
From: Amit Garg
Date: Fri, 14 Jun 2024 16:32:00 -0700
Subject: [PATCH 1/5] renamed phi3 rope_scaling type

---
 .../models/phi3/configuration_phi3.py         | 22 +++++++--
 src/transformers/models/phi3/modeling_phi3.py | 49 ++-----------------
 2 files changed, 21 insertions(+), 50 deletions(-)

diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py
index 0e80566f5455..4e2bf3b4c0fb 100644
--- a/src/transformers/models/phi3/configuration_phi3.py
+++ b/src/transformers/models/phi3/configuration_phi3.py
@@ -78,8 +78,8 @@ class Phi3Config(PretrainedConfig):
             The base period of the RoPE embeddings.
         rope_scaling (`dict`, *optional*):
             The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
-            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
-            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and 
+            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size 
             divided by the number of attention heads divided by 2.
         bos_token_id (`int`, *optional*, defaults to 1):
             The id of the "beginning-of-sequence" token.
@@ -155,6 +155,7 @@ def __init__(
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        self._rope_scaling_adjustment()
         self._rope_scaling_validation()
         self.sliding_window = sliding_window
 
@@ -166,6 +167,19 @@ def __init__(
             **kwargs,
         )
 
+    def _rope_scaling_adjustment(self):
+        """
+        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
+        """
+        if self.rope_scaling is None:
+            return
+
+        rope_scaling_type = self.rope_scaling.get("type", None)
+
+        # For backward compatibility if previous version used "su" or "yarn"
+        if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
+            self.rope_scaling["type"] = "longrope"
+
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -181,8 +195,8 @@ def _rope_scaling_validation(self):
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
         rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
-            raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
+            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
         if not (
             isinstance(rope_scaling_short_factor, list)
             and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index e14785bd1f8b..f7786a0bd023 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -120,7 +120,7 @@ def forward(self, x, position_ids, seq_len=None):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
+class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
 
@@ -161,47 +161,6 @@ def forward(self, x, position_ids, seq_len=None):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
-    def __init__(self, dim, config, device=None):
-        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
-
-        self.short_factor = config.rope_scaling["short_factor"]
-        self.long_factor = config.rope_scaling["long_factor"]
-        self.original_max_position_embeddings = config.original_max_position_embeddings
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.original_max_position_embeddings:
-            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
-        else:
-            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
-
-        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
-        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
-
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-
-            scale = self.max_position_embeddings / self.original_max_position_embeddings
-            if scale <= 1.0:
-                scaling_factor = 1.0
-            else:
-                scaling_factor = 0.1 * math.log(scale) + 1.0
-
-            cos = emb.cos() * scaling_factor
-            sin = emb.sin() * scaling_factor
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -316,10 +275,8 @@ def _init_rope(self):
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
-            if scaling_type == "su":
-                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
-            elif scaling_type == "yarn":
-                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
+            if scaling_type == "longrope":
+                self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 

From e4a5e9d053ae0887ceea56a7981edf9f3c2780b2 Mon Sep 17 00:00:00 2001
From: Amit Garg
Date: Fri, 14 Jun 2024 17:19:08 -0700
Subject: [PATCH 2/5] fixed trailing whitespaces

---
 src/transformers/models/phi3/configuration_phi3.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py
index 4e2bf3b4c0fb..5d6f9f389885 100644
--- a/src/transformers/models/phi3/configuration_phi3.py
+++ b/src/transformers/models/phi3/configuration_phi3.py
@@ -78,8 +78,8 @@ class Phi3Config(PretrainedConfig):
             The base period of the RoPE embeddings.
         rope_scaling (`dict`, *optional*):
             The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
-            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and 
-            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size 
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
+            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
             divided by the number of attention heads divided by 2.
         bos_token_id (`int`, *optional*, defaults to 1):
             The id of the "beginning-of-sequence" token.
@@ -179,7 +179,7 @@ def _rope_scaling_adjustment(self):
         # For backward compatibility if previous version used "su" or "yarn"
         if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
             self.rope_scaling["type"] = "longrope"
-        
+
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
From f9d19887a87c78da81fe0d92c4c56cbe710b4dda Mon Sep 17 00:00:00 2001
From: Amit Garg
Date: Fri, 14 Jun 2024 17:32:59 -0700
Subject: [PATCH 3/5] fixed test

---
 tests/models/phi3/test_modeling_phi3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py
index ad9c4c46aa93..1ddc73961bfe 100644
--- a/tests/models/phi3/test_modeling_phi3.py
+++ b/tests/models/phi3/test_modeling_phi3.py
@@ -362,7 +362,7 @@ def test_phi3_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    @parameterized.expand([("su",), ("yarn",)])
+    @parameterized.expand([("longrope",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)

From 7ffa952b83bf667c5a1330cf70b15797c4ed9c04 Mon Sep 17 00:00:00 2001
From: Amit Garg
Date: Wed, 10 Jul 2024 11:54:45 -0700
Subject: [PATCH 4/5] added warning

---
 src/transformers/models/phi3/modeling_phi3.py | 86 +++++++++++++++++++
 1 file changed, 86 insertions(+)

diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 810dc5c3e68f..fd637ad68782 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -17,6 +17,7 @@
 
 import inspect
 import math
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -122,6 +123,91 @@ def forward(self, x, position_ids, seq_len=None):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
+            " use Phi3LongRoPEScaledRotaryEmbedding instead.",
+            FutureWarning,
+        )
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
+            FutureWarning,
+        )
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = 0.1 * math.log(scale) + 1.0
+
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

From 0de6511df9a35bce8ceed3d149ec871e6723bcab Mon Sep 17 00:00:00 2001
From: Amit Garg
Date: Wed, 10 Jul 2024 12:04:19 -0700
Subject: [PATCH 5/5] fixed format

---
 src/transformers/models/phi3/modeling_phi3.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index fd637ad68782..7a047fcbe061 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -135,6 +135,7 @@ def __init__(self, dim, config, device=None):
         self.short_factor = config.rope_scaling["short_factor"]
         self.long_factor = config.rope_scaling["long_factor"]
         self.original_max_position_embeddings = config.original_max_position_embeddings
+
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
         seq_len = torch.max(position_ids) + 1
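With PATCH 4/5 and 5/5 applied, Phi3LongRoPEScaledRotaryEmbedding is the class selected for rope_scaling type "longrope", while the old Phi3SuScaledRotaryEmbedding and Phi3YarnScaledRotaryEmbedding names stay importable and emit a FutureWarning when constructed. A minimal sketch of that deprecation path, assuming a transformers build containing these patches and using placeholder scaling factors (a head_dim of 96 follows from the default hidden_size=3072 and num_attention_heads=32):

# Illustrative sketch, not part of the patch series; placeholder factor values.
import warnings

from transformers import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3SuScaledRotaryEmbedding

n_factors = 3072 // 32 // 2  # 48 factors for the default config
config = Phi3Config(
    rope_scaling={
        "type": "su",  # remapped to "longrope" by the config; the class itself only reads the factor lists
        "short_factor": [1.0] * n_factors,
        "long_factor": [1.0] * n_factors,
    }
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    rotary_emb = Phi3SuScaledRotaryEmbedding(96, config)  # 96 == head_dim for the default config

assert any(issubclass(w.category, FutureWarning) for w in caught)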