20 changes: 17 additions & 3 deletions src/transformers/models/phi3/configuration_phi3.py
@@ -78,7 +78,7 @@ class Phi3Config(PretrainedConfig):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
divided by the number of attention heads divided by 2.
bos_token_id (`int`, *optional*, defaults to 1):
@@ -155,6 +155,7 @@ def __init__(
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_adjustment()
self._rope_scaling_validation()
self.sliding_window = sliding_window

@@ -166,6 +167,19 @@ def __init__(
**kwargs,
)

def _rope_scaling_adjustment(self):
"""
Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
"""
if self.rope_scaling is None:
return

rope_scaling_type = self.rope_scaling.get("type", None)

# For backward compatibility if previous version used "su" or "yarn"
if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
self.rope_scaling["type"] = "longrope"

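A minimal usage sketch of the backward-compatibility path above, assuming the default Phi3Config geometry (hidden_size 3072 and 32 attention heads, so each factor list needs 3072 / 32 / 2 = 48 entries); the factor values are placeholders, not tuned LongRoPE factors:

from transformers import Phi3Config

factors = [1.0] * 48  # placeholder factors of the required length
config = Phi3Config(rope_scaling={"type": "su", "short_factor": factors, "long_factor": factors})
print(config.rope_scaling["type"])  # "longrope" -- rewritten by _rope_scaling_adjustment before validation runs
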
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
@@ -181,8 +195,8 @@ def _rope_scaling_validation(self):
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
if not (
isinstance(rope_scaling_short_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
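And the corresponding failure mode: any type other than the (possibly upgraded) "longrope" now trips the tightened check with the new message. A quick sketch reusing the placeholder factors from above:

from transformers import Phi3Config

factors = [1.0] * 48
try:
    Phi3Config(rope_scaling={"type": "linear", "short_factor": factors, "long_factor": factors})
except ValueError as err:
    print(err)  # `rope_scaling`'s type field must be one of ['longrope'], got linear
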
58 changes: 51 additions & 7 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -17,6 +17,7 @@

import inspect
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
@@ -124,6 +125,51 @@ def forward(self, x, position_ids, seq_len=None):

class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
def __init__(self, dim, config, device=None):
warnings.warn(
"The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
" use Phi3LongRoPEScaledRotaryEmbedding instead.",
Comment on lines +129 to +130
Collaborator: maybe add `by setting rope_type to long_rope`

FutureWarning,
)
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

self.short_factor = config.rope_scaling["short_factor"]
self.long_factor = config.rope_scaling["long_factor"]
self.original_max_position_embeddings = config.original_max_position_embeddings

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
seq_len = torch.max(position_ids) + 1
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


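The attention-temperature rescaling in the forward above is easy to sanity-check by hand. A sketch assuming the 128k Phi-3-mini geometry (max_position_embeddings=131072, original_max_position_embeddings=4096, i.e. a 32x context extension):

import math

max_pos, orig_max_pos = 131072, 4096  # assumed 128k-variant geometry
scale = max_pos / orig_max_pos  # 32.0, so the scale > 1.0 branch is taken
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos))
print(scaling_factor)  # ~1.1902, since log(32) / log(4096) is exactly 5/12
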
class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
def __init__(self, dim, config, device=None):
warnings.warn(
"The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
FutureWarning,
)
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

self.short_factor = config.rope_scaling["short_factor"]
@@ -156,14 +202,14 @@ def forward(self, x, position_ids, seq_len=None):
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
scaling_factor = 0.1 * math.log(scale) + 1.0

cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


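For comparison, the deprecated YaRN variant keeps its own mscale-style temperature, 0.1 * log(scale) + 1.0, which grows noticeably faster; with the same assumed 32x extension:

import math

scale = 32.0  # same assumed extension ratio as above
print(0.1 * math.log(scale) + 1.0)  # ~1.3466, vs ~1.1902 for the longrope formula
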
class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
def __init__(self, dim, config, device=None):
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

@@ -197,7 +243,7 @@ def forward(self, x, position_ids, seq_len=None):
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = 0.1 * math.log(scale) + 1.0
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
@@ -318,10 +364,8 @@ def _init_rope(self):
)
else:
scaling_type = self.config.rope_scaling["type"]
if scaling_type == "su":
self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
elif scaling_type == "yarn":
self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
if scaling_type == "longrope":
self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

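An end-to-end sketch tying the pieces together: an old-style "su" config is upgraded at construction time and then routed to the renamed class by _init_rope. A hypothetical check, with placeholder factors as before and assuming Phi3Attention's (config, layer_idx) constructor from this file:

from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3LongRoPEScaledRotaryEmbedding

factors = [1.0] * 48
config = Phi3Config(rope_scaling={"type": "su", "short_factor": factors, "long_factor": factors})
attn = Phi3Attention(config, layer_idx=0)
assert isinstance(attn.rotary_emb, Phi3LongRoPEScaledRotaryEmbedding)
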
2 changes: 1 addition & 1 deletion tests/models/phi3/test_modeling_phi3.py
@@ -362,7 +362,7 @@ def test_phi3_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

@parameterized.expand([("su",), ("yarn",)])
@parameterized.expand([("longrope",)])
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)