Refactor: Use Llama RoPE implementation for Falcon #26933
gante merged 7 commits into huggingface:main
|
Before diving into the code, I tried the snippet you shared, with one extra addition:
from transformers import AutoTokenizer, pipeline, set_seed
import torch

model = "tiiuae/falcon-7b"
set_seed(0)
tokenizer = AutoTokenizer.from_pretrained(model)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
sequences = pipe(
    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
    max_length=300,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")
There is a tiny difference in outputs, which we should try to figure out before merging (maybe I'll have some clues after looking at the diff). Using the latest commit here, the output is 👉 identical up to the last two sentences.
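When two sampled generations are "identical up to the last two sentences", it can help to pin down exactly where they split. The following is only a small sketch using the standard library; the helper name `first_divergence` and both example strings are hypothetical placeholders, not outputs from Falcon.

```python
import os


def first_divergence(a: str, b: str) -> int:
    """Return the index of the first character where a and b differ.

    If one string is a prefix of the other, this is the length of the shorter one.
    """
    return len(os.path.commonprefix([a, b]))


# Placeholder strings standing in for the two generations (not real model output).
main_out = "Girafatron: Hello! I love giraffes. They are tall."
pr_out = "Girafatron: Hello! I love giraffes. They are great."

idx = first_divergence(main_out, pr_out)
print(f"Shared prefix ({idx} chars): {main_out[:idx]!r}")
print(f"main continues with: {main_out[idx:]!r}")
print(f"PR continues with:   {pr_out[idx:]!r}")
```

With sampling enabled, a single differing token changes everything after it, so the position of the first divergence is usually more informative than the total amount of differing text.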
|
Note: adding
if dtype in [torch.float16, torch.bfloat16]:
    emb = emb.float()
back doesn't fix the mismatch
The changes look good to me, thank you for working on them!
Re apply_rotary_pos_emb, I think we can go one step further and copy it from Llama as well, if it makes downstream work easier. Looking at the code, we apply a reshape such that the first dim of QKV is batch_size * self.num_heads, run apply_rotary_pos_emb (and nothing for alibi), and then reshape back to batch_size, self.num_heads, .... It seems like we would save two sets of reshape operations if we do it the Llama way 🤔
In addition to this, we should assess:
a) if this minor mismatch is relevant, perhaps by measuring ppl on some test set
b) if there are throughput changes
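The reshape round-trip mentioned above can be illustrated with a small sketch. It assumes a simplified rotary embedding (no `position_ids`, no KV cache) and uses hypothetical helpers `rope_tables` and `apply_rope`; it is not the code from `modeling_falcon.py` or `modeling_llama.py`. It only shows that applying RoPE directly on a `[batch_size, num_heads, seq_len, head_dim]` tensor gives the same result as flattening to `[batch_size * num_heads, seq_len, head_dim]`, applying it, and reshaping back.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the two halves of the last dimension (x1, x2) to (-x2, x1)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def rope_tables(seq_len: int, head_dim: int, base: float = 10000.0):
    """Build cos/sin tables of shape [seq_len, head_dim], computed in float32."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos and sin broadcast over all leading dimensions of x.
    return x * cos + rotate_half(x) * sin


batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
torch.manual_seed(0)
query = torch.randn(batch_size, num_heads, seq_len, head_dim)
cos, sin = rope_tables(seq_len, head_dim)

# Llama-style: apply RoPE directly on the 4D tensor.
out_4d = apply_rope(query, cos, sin)

# Flattened style: merge batch and heads, apply RoPE, then reshape back.
flat = query.reshape(batch_size * num_heads, seq_len, head_dim)
out_3d = apply_rope(flat, cos, sin).reshape(batch_size, num_heads, seq_len, head_dim)

same = torch.allclose(out_4d, out_3d)
print(same)
```

Under these assumptions the two paths agree, so the 4D path simply avoids one reshape before and one after the rotary step.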
|
All of that sounds great. I didn't notice the potential to remove some reshapes; that would definitely be useful. And I'll reuse my existing benchmarking tools. Will work on this tomorrow!
|
We don't need to convert any caches anymore!
|
I have some more results @gante 🎉 I've done tests for [...]
Perplexity
I've tried 3 different books, and for all of them there are some notable differences in favor of [...]
If I try a different book, again there are some notable differences in favor of [...]
Yet another book, again large differences in favor of [...]
VRAM
Latency
The latency for v2 definitely improves over [...]
One thing: I get a test failure locally for [...]
This also simplifies implementing the Attention Sink Cache for Falcon 🎉
|
|
I'll pull this into #26681, i.e. the caching refactor PR, once this is merged. Then I can conveniently test whether the implementation is now easily extended to Falcon too.
|
Resolved the merge conflicts introduced by #26792 by @patrickvonplaten; the latency now seems slightly better than reported in #26933 (comment). This PR should be good to go as far as I'm concerned. The same goes for #26929, which solves some minor issues with the Falcon config.
|
Woah, that is a great analysis of the changes! Great work Tom 💛
gante left a comment
Less code, better modelling performance, and higher throughput -- the dream PR 💛
|
Tagging @ArthurZucker (core maintainer) for a quick final check
|
FYI @tomaarsen: a major cause for the PPL mismatch seems to stem from how [...]
|
That must be it indeed! In
transformers/src/transformers/models/falcon/modeling_falcon.py Lines 247 to 249 in 576e282
which is called like so:
transformers/src/transformers/models/falcon/modeling_falcon.py Lines 262 to 267 in 576e282
which is called like so:
transformers/src/transformers/models/falcon/modeling_falcon.py Lines 278 to 280 in 576e282
So, if the model is loaded in fp16, then so is the query, which gives t = torch.arange(seq_len, device=device).to(torch.float16). This results in:
>>> torch.arange(3000, dtype=torch.float16)[-10:]
tensor([2990., 2992., 2992., 2992., 2994., 2996., 2996., 2996., 2998., 3000.],
       dtype=torch.float16)
Alternatively, if the model is loaded in bf16, then it uses:
>>> torch.arange(3000, dtype=torch.bfloat16)[-10:]
tensor([2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992.],
       dtype=torch.bfloat16)
which is just awful! With this PR, we get:
transformers/src/transformers/models/falcon/modeling_falcon.py Lines 256 to 258 in dcde537
and
transformers/src/transformers/models/falcon/modeling_falcon.py Lines 248 to 249 in dcde537
which is always float32 due to the .float() call. So, we get t = torch.arange(seq_len, device=device).to(torch.float32), which results in:
>>> torch.arange(3000, dtype=torch.float32)[-10:]
tensor([2990., 2991., 2992., 2993., 2994., 2995., 2996., 2997., 2998., 2999.])
So, this PR solves a hidden bug that has been resulting in reduced performance at higher sequence lengths. After all, the problem only gets worse at higher sequence lengths, e.g.:
>>> torch.arange(30000, dtype=torch.bfloat16)[-10:]
tensor([29952., 29952., 29952., 29952., 29952., 29952., 29952., 29952., 29952.,
        29952.], dtype=torch.bfloat16)
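The collapse of neighbouring positions follows from the number of significand bits: fp16 keeps 11 and bf16 only 8, so between 2048 and 4096 the representable values are 2 apart in fp16 and 16 apart in bf16. The sketch below reproduces the tensors above without PyTorch by emulating the rounding with the standard library. The helper names `round_to_fp16` and `round_to_bf16` are hypothetical, and the bf16 emulation (rounding a float32 bit pattern to its top 16 bits, ties to even) is an assumption about how the conversion behaves for finite values.

```python
import struct


def round_to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision value via the stdlib 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); finite inputs only.

    bfloat16 is the top 16 bits of a float32, so the low 16 bits are rounded away.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


positions = [float(p) for p in range(2990, 3000)]
fp16_positions = [round_to_fp16(p) for p in positions]
bf16_positions = [round_to_bf16(p) for p in positions]

print(fp16_positions)  # only even integers survive in this range
print(bf16_positions)  # all ten positions collapse to 2992.0
print(len(set(positions)), len(set(fp16_positions)), len(set(bf16_positions)))
```

Ten distinct positions become six in fp16 and a single one in bf16, which means the rotary embedding can no longer tell those tokens apart. Building the position index in float32 and casting only the final cos/sin values avoids this.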
|
|
I've resolved the outstanding merge conflicts. This should be ready for final review @ArthurZucker. I've verified that the results are identical to what I previously plotted here. I want to point out: this PR is completely unrelated to attention sinks (although it will eventually help with the implementation). These findings hold for regular usage, without attention sinks. Using these scripts you can reproduce my findings:
perplexity.py
"""
Adapted from https://github.com/mit-han-lab/streaming-llm
Note: Although this script measures latency, it is not optimized whatsoever!
The latency is only tracked to see the impact of speed over time.
Usage:
python benchmark/perplexity.py --experiment attention_sinks
python benchmark/perplexity.py --experiment transformers
python benchmark/perplexity.py --experiment windowed
"""
import argparse
import itertools
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import pandas as pd
import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


def compute_perplexity(
    model,
    tokenizer,
    dataset,
    experiment: str,
    output_dir: str = "outputs",
    data_column: str = "text",
    num_samples: int = 1,
    num_tokens: Optional[int] = None,
    overwrite: bool = False,
) -> None:
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{experiment}.csv"
    if output_file.exists() and not overwrite:
        raise ValueError(
            f"The {output_file!r} output file already exists - if you really want to override it, then use `--overwrite`."
        )
    logs = defaultdict(list)
    loss_fn = CrossEntropyLoss(reduction="none")
    past_key_values = None
    num_processed_tokens = 0
    for text in itertools.islice(dataset, num_samples):
        encodings = tokenizer(text[data_column], return_tensors="pt")
        seq_len = encodings.input_ids.size(1)
        print(f"sequence length: {seq_len}")
        pbar = tqdm(range(0, seq_len - 1))
        for idx in pbar:
            start_t = time.time()
            input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)
            with torch.no_grad():
                outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
                logits = outputs.logits.view(-1, model.config.vocab_size)
                past_key_values = outputs.past_key_values
                label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
                neg_log_likelihood = loss_fn(logits, label)
                perplexity = neg_log_likelihood.exp()
            pbar.set_description(f"nll: {neg_log_likelihood.item():>5.2f}, ppl: {perplexity.item():>8.2f}")
            # Store data and save every 10 tokens
            logs["input_length"].append(idx + 1)
            logs["nll"].append(neg_log_likelihood.item())
            logs["ppl"].append(perplexity.item())
            logs["overall_ppl"].append(torch.tensor(logs["nll"]).mean().exp().item())
            logs["cuda_vram_allocated"].append(torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024)  # in GB
            logs["latency"].append(time.time() - start_t)
            if num_processed_tokens % 10 == 0:
                try:
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                except KeyboardInterrupt as ex:
                    # If there's a Keyboard Interrupt, still write the file, and then stop
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                    raise ex
            num_processed_tokens += 1
            if num_tokens and num_processed_tokens >= num_tokens:
                return


def main():
    parser = argparse.ArgumentParser()
    # How to call this experiment?
    parser.add_argument(
        "--experiment", type=str, default="main"
    )
    # Model args
    parser.add_argument("--model_name_or_path", type=str, default="tiiuae/falcon-7b")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")
    # Dataset args
    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
    parser.add_argument("--data_column", type=str, default="text")
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
    # parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--num_tokens", type=int, default=5000)
    # Where to log
    parser.add_argument("--output_dir", type=str, default="perplexity_benchmark")
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=bool(args.trust_remote_code),
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))
    # Set up the dataset
    dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)
    compute_perplexity(
        model,
        tokenizer,
        dataset,
        args.experiment,
        output_dir=args.output_dir,
        data_column=args.data_column,
        num_samples=1,  # <- No support for more than one instance now
        num_tokens=args.num_tokens,
        overwrite=args.overwrite,
    )


if __name__ == "__main__":
    main()
plot_perplexity.py
"""
First run `perplexity.py` to generate one or more `csv` files.
This script can plot those csv files.
Usage:
python benchmark/plot_perplexity.py
python benchmark/plot_perplexity.py --features perplexity latency --title "Log perplexity & latency of Llama 2 7B as a function of input lengths"
"""
import argparse
from pathlib import Path
from typing import List, Optional
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

FEATURE_DF_MAP = {
    "perplexity": "overall_ppl",
    "vram": "cuda_vram_allocated",
    "latency": "latency",
}
FEATURE_STYLE_MAP = {
    "perplexity": "-",
    "vram": "--",
    "latency": ":",
}
FEATURE_LABEL_MAP = {
    "perplexity": "Perplexity (log), lower is better",
    "vram": "CUDA VRAM Usage (GB), lower is better",
    "latency": "Time per token (sec), lower is better",
}


def plot(
    features: List[str],
    output_dir: str = "outputs",
    title: Optional[str] = None,
    perplexity_limit: Optional[float] = None,
    skip_first: int = 100,
):
    output_dir = Path(output_dir)
    fig, ax = plt.subplots()
    ax.set_xlabel("Input Sequence Length")
    for feature_i, feature in enumerate(features):
        # If we already plotted on this ax, make a new one
        if feature_i:
            ax = ax.twinx()
        for file in output_dir.glob("*.csv"):
            experiment = file.stem
            df = pd.read_csv(file)
            X = df["input_length"][skip_first:]
            Y = df[FEATURE_DF_MAP[feature]][skip_first:]
            if feature == "perplexity":
                Y = np.log(Y)
            if feature == "latency":
                poly = np.polyfit(X, Y, 20)
                poly_y = np.poly1d(poly)(X)
                ax.plot(X, poly_y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")
            else:
                ax.plot(X, Y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")
        ax.set_ylabel(FEATURE_LABEL_MAP[feature])
        if perplexity_limit and feature == "perplexity":
            ax.set_ylim(top=min(ax.get_ylim()[1], perplexity_limit))
        ax.legend(loc=[1, 2, 7][feature_i])  # upper right, upper left, center right
    ax.set_title(title.replace("\\n", "\n") if title else "Log perplexity as a function of input lengths")
    fig.tight_layout()
    return fig


def main():
    parser = argparse.ArgumentParser()
    # Where csv files have been logged
    parser.add_argument("--output_dir", type=str, default="perplexity_benchmark")
    parser.add_argument(
        "--features", choices=["perplexity", "vram", "latency"], nargs="+", default=["perplexity", "vram"]
    )
    parser.add_argument("--title", type=str, default=None)
    parser.add_argument("--log_perplexity_limit", type=float, default=5.0)
    # Perplexity starts a bit unstable, so we skip the start
    parser.add_argument("--skip_first", type=int, default=100)
    args = parser.parse_args()
    figure = plot(
        args.features,
        output_dir=args.output_dir,
        title=args.title,
        perplexity_limit=args.log_perplexity_limit,
        skip_first=args.skip_first,
    )
    # Add your own code here if you'd like to change the figure
    plt.show()


if __name__ == "__main__":
    main()
Usage:
git checkout main
# the --experiment just determines the filename
python ./perplexity.py --experiment main
git checkout pr-26933 # <- Or whatever branch you use locally for this PR
python ./perplexity.py --experiment llama_rope_for_falcon
python ./plot_perplexity.py
And you'll get a plot just like here.
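The `overall_ppl` column that the perplexity script logs is the exponential of the mean negative log-likelihood over all tokens seen so far. A minimal standard-library sketch of that running quantity follows; the generator name `running_perplexity` and the NLL values are hypothetical placeholders, not measurements from Falcon.

```python
import math


def running_perplexity(token_nlls):
    """Yield exp(mean NLL so far) after each token, mirroring `overall_ppl`."""
    total = 0.0
    for count, nll in enumerate(token_nlls, start=1):
        total += nll
        yield math.exp(total / count)


# Placeholder per-token negative log-likelihoods (not measured values).
nlls = [2.0, 1.0, 3.0]
overall = list(running_perplexity(nlls))
print(overall)
```

Because this is a running mean, a few poorly predicted tokens late in a long sequence shift the curve only slowly, so a sustained gap between two runs points to a systematic difference rather than noise.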
|
(updating core maintainer to review :) )
amyeroberts left a comment
Wow - a really great PR! 🔥
Thank you for all the work in refactoring this, writing up such detailed explanations and all of the investigative work.
* Use Llama RoPE implementation for Falcon + Add copy functionalities
* Use standard cache format for Falcon
* Simplify apply_rotary_pos_emb, copy from Llama
* Remove unnecessary cache conversion test
  We don't need to convert any caches anymore!
* Resolve copy complaint

What does this PR do?
transformers/src/transformers/models/falcon/modeling_falcon.py
Line 91 in ad08137
Before submitting
Discussed this internally with @gante.
Details
There are a few differences between Llama and Falcon that complicate this somewhat. In particular, Llama deals with [batch_size, num_(kv)_heads, seq_len, head_dim] on the KVQ states, while Falcon uses [batch_size * num_(kv)_heads, seq_len, head_dim], i.e. one dimension less.
This is why apply_rotary_pos_emb uses torch.repeat_interleave a few times to manually expand the cos, sin when necessary. This used to be in the forward of the FalconRotaryEmbedding.
There are still some differences between the old and new implementations:
* [...] .cos() and .sin() in float32:
  transformers/src/transformers/models/falcon/modeling_falcon.py
  Lines 115 to 116 in ad08137
* [...] forward call, while the Llama RoPE doesn't:
  transformers/src/transformers/models/falcon/modeling_falcon.py
  Lines 131 to 133 in ad08137
  (Should we also implement this on Llama? In case someone wants to move a model on the fly)
In the context of Attention Sinks
For Attention Sinks (context: #26681), it looks like the SinkCache must store an apply_rotary_pos_emb variable or something, because for Falcon it will need to use a different apply_rotary_pos_emb function.
How did I test?
I tested this by running pytest tests/models/falcon and a text-generation script, and observing that the generated text was of similar quality to that of the main branch.
Note: I didn't run this with Falcon 40b! I would heavily recommend modifying that script to use 40b and ensuring that it runs correctly with that model too. Falcon 40b uses "new_decoder_architecture": true while 7b uses "new_decoder_architecture": false.
Who can review?
@gante