14 changes: 7 additions & 7 deletions benchmark/llama.py
@@ -118,7 +118,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):
    with torch.no_grad():
        past_key_values = StaticCache(
            model.config,
-           max_batch_size=batch_size,
+           batch_size=batch_size,
            device=device,
            dtype=torch.float16,
            max_cache_len=seq_length + num_tokens_to_generate,
@@ -144,7 +144,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

    past_key_values = StaticCache(
        model.config,
-       max_batch_size=batch_size,
+       batch_size=batch_size,
        device=device,
        dtype=torch.float16,
        max_cache_len=seq_length + num_tokens_to_generate,
@@ -187,7 +187,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):
    # TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
    past_key_values = StaticCache(
        model.config,
-       max_batch_size=batch_size,
+       batch_size=batch_size,
        device=device,
        dtype=torch.float16,
        max_cache_len=seq_length + num_tokens_to_generate + 10,
@@ -254,7 +254,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

    past_key_values = StaticCache(
        model.config,
-       max_batch_size=batch_size,
+       batch_size=batch_size,
        device=device,
        dtype=torch.float16,
        max_cache_len=seq_length + 128,
@@ -271,7 +271,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

    past_key_values = StaticCache(
        model.config,
-       max_batch_size=batch_size,
+       batch_size=batch_size,
        device=device,
        dtype=torch.float16,
        max_cache_len=seq_length + 128,
@@ -287,7 +287,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

    past_key_values = StaticCache(
        model.config,
-       max_batch_size=batch_size,
+       batch_size=batch_size,
        device=device,
        dtype=torch.float16,
        max_cache_len=seq_length + 128,
@@ -303,7 +303,7 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):

    past_key_values = StaticCache(
        model.config,
-       max_batch_size=batch_size,
+       batch_size=batch_size,
        device=device,
        dtype=torch.float16,
        max_cache_len=seq_length + 128,
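For reference, the updated call-site pattern used throughout the benchmark looks like the sketch below. It is not part of the PR: it loads only a model config via `AutoConfig` (assuming you have access to the `meta-llama/Llama-2-7b-chat-hf` repository; any decoder-only config works), and the batch size, prompt length, and generation length are placeholder values.

```python
import torch
from transformers import AutoConfig, StaticCache

# Illustrative values only; the benchmark derives these from its own arguments.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
batch_size, seq_length, num_tokens_to_generate = 1, 16, 32

# After this PR, the leading cache dimension is passed as `batch_size`
# instead of `max_batch_size`; the remaining arguments are unchanged.
past_key_values = StaticCache(
    config,
    batch_size=batch_size,
    device="cpu",
    dtype=torch.float16,
    max_cache_len=seq_length + num_tokens_to_generate,
)
```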
6 changes: 3 additions & 3 deletions docs/source/en/kv_cache.md
@@ -336,9 +336,9 @@ model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

-# Init StaticCache with big enough max-length (1024 tokens for the below example)
+# Init StaticCache with big enough max-length (1024 tokens for the below example)
# You can also init a DynamicCache, if that suits you better
-prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+prompt_cache = StaticCache(config=model.config, batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)

INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
@@ -351,7 +351,7 @@ responses = []
for prompt in prompts:
    new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
    past_key_values = copy.deepcopy(prompt_cache)
-   outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
+   outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
    response = tokenizer.batch_decode(outputs)[0]
    responses.append(response)
4 changes: 2 additions & 2 deletions docs/source/en/llm_optims.md
@@ -93,7 +93,7 @@ model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
    config=model.config,
-   max_batch_size=1,
+   batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
    device=model.device,
@@ -159,7 +159,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    past_key_values = StaticCache(
-       config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+       config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
    )
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
8 changes: 4 additions & 4 deletions docs/source/en/model_doc/gemma2.md
@@ -58,7 +58,7 @@ pipe("Explain quantum computing simply. ", max_new_tokens=50)

</hfoption>
<hfoption id="AutoModel">

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -89,7 +89,7 @@ echo -e "Explain quantum computing simply." | transformers-cli run --task text-g
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.

```python
@@ -118,7 +118,7 @@ Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/bl
```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")
-visualizer("You are an assistant. Make sure you print me")
+visualizer("You are an assistant. Make sure you print me")
```

<div class="flex justify-center">
@@ -137,7 +137,7 @@ visualizer("You are an assistant. Make sure you print me")

inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
max_generated_length = inputs.input_ids.shape[1] + 10
-past_key_values = HybridCache(config=model.config, max_batch_size=1,
+past_key_values = HybridCache(config=model.config, batch_size=1,
    max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```
4 changes: 2 additions & 2 deletions docs/source/ko/llm_optims.md
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
    config=model.config,
-   max_batch_size=1,
+   batch_size=1,
    # 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다
    max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
    device=model.device,
@@ -161,7 +161,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    past_key_values = StaticCache(
-       config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+       config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
    )
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
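Code that has to run against both older transformers releases (which expect `max_batch_size`) and versions including this change can choose the keyword at runtime. The sketch below is not part of this PR; it only assumes the standard-library `inspect` module, and the helper name `make_static_cache` is hypothetical.

```python
import inspect

import torch
from transformers import StaticCache


def make_static_cache(config, batch_size, max_cache_len, device, dtype=torch.float16):
    """Build a StaticCache with whichever batch-size keyword this transformers version accepts."""
    params = inspect.signature(StaticCache.__init__).parameters
    batch_kwarg = "batch_size" if "batch_size" in params else "max_batch_size"
    return StaticCache(
        config=config,
        max_cache_len=max_cache_len,
        device=device,
        dtype=dtype,
        **{batch_kwarg: batch_size},
    )
```

A caller that needs to stay version-agnostic would then write `make_static_cache(model.config, 1, 1024, model.device)` instead of calling the constructor with either keyword directly.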