
fix static cache data type miss-match #34799
Merged
ArthurZucker merged 14 commits into huggingface:main from jiqing-feng:gptj on Nov 25, 2024

Conversation

@jiqing-feng (Contributor) commented Nov 19, 2024

Hi @SunMarc. This PR fixes the data type mismatch that occurs when using a low-precision static cache. The following code reproduces the bug:

import torch
from transformers import pipeline

model_id = "EleutherAI/gpt-j-6b"
model_kwargs = {"torch_dtype": torch.bfloat16}

pipe = pipeline("text-generation", model=model_id, model_kwargs=model_kwargs)

generation_config = pipe.model.generation_config
generation_config.cache_implementation="static"

print(pipe("I am happy because", generation_config=generation_config))

Output:

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Traceback (most recent call last):
  File "/home/jiqingfe/test_gptj.py", line 13, in <module>
    print(pipe("I am happy because", generation_config=generation_config))
  File "/home/jiqingfe/transformers/src/transformers/pipelines/text_generation.py", line 272, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/base.py", line 1301, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/base.py", line 1308, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/base.py", line 1208, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/text_generation.py", line 370, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/generation/utils.py", line 2263, in generate
    result = self._beam_search(
  File "/home/jiqingfe/transformers/src/transformers/generation/utils.py", line 3472, in _beam_search
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/models/gptj/modeling_gptj.py", line 1098, in forward
    transformer_outputs = self.transformer(
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/models/gptj/modeling_gptj.py", line 838, in forward
    outputs = block(
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/models/gptj/modeling_gptj.py", line 453, in forward
    attn_outputs = self.attn(
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/idp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/models/gptj/modeling_gptj.py", line 246, in forward
    key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
  File "/home/jiqingfe/transformers/src/transformers/cache_utils.py", line 1220, in update
    k_out.index_copy_(2, cache_position, key_states)
RuntimeError: index_copy_(): self and source expected to have the same dtype, but got (self) BFloat16 and (source) Float

@jiqing-feng (Contributor, Author):

BTW, I suppose transformers is missing some static cache tests. Do you have any guidance on where I can add this kind of test? Thanks!

@SunMarc (Member) commented Nov 19, 2024

BTW, I suppose transformers is missing some static cache tests. Do you have any guidance on where I can add this kind of test? Thanks!

All the tests related to the cache are in the test_utils.py file. Inside GenerationTesterMixin you will find the tests we run on all models, and GenerationIntegrationTests contains the integration tests.
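
For reference, a minimal standalone check along these lines would exercise the failing path. This is only an illustrative sketch, not the actual test added to GenerationTesterMixin; the model id, prompt, and token count are placeholders:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_low_precision_static_cache():
    # Illustrative sketch only; the real coverage lives in GenerationTesterMixin.
    model_id = "EleutherAI/gpt-j-6b"  # any causal LM that supports the static cache
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    inputs = tokenizer("I am happy because", return_tensors="pt")
    # Before this fix, a bf16 model with a static cache failed with
    # "index_copy_(): self and source expected to have the same dtype".
    output = model.generate(**inputs, cache_implementation="static", max_new_tokens=8)
    assert output.shape[-1] > inputs["input_ids"].shape[-1]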

@SunMarc (Member) left a comment


Thanks for the bug fix! Left a comment.

Comment on lines +236 to +237
key = key.permute(0, 2, 1, 3).to(value.dtype)
query = query.permute(0, 2, 1, 3).to(value.dtype)
@SunMarc (Member):

Could you explain why this is needed for this particular model, and why this doesn't happen for llama, for example? Many models have approximately the same modeling code.

@jiqing-feng (Contributor, Author):

For llama, sin and cos come from position_embeddings (a bf16 tensor), which comes from here; you can see that llama's rotary embedding converts the data type. But for gptj, the position embeddings come from here, and they are created in float32, so a data type mismatch happens when the input data type is bf16 or fp16.
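
As an illustration of the root cause (the tensor names and shapes below are made up for the example, not GPT-J's actual code): PyTorch type promotion upcasts a bf16 tensor multiplied by a float32 tensor to float32, so key states rotated with float32 sin/cos tables no longer match the bf16 static cache buffers.

import torch

key = torch.randn(1, 8, 4, 16, dtype=torch.bfloat16)  # hypothetical bf16 key states
sin = torch.randn(4, 16, dtype=torch.float32)          # float32 positional table
rotated = key * sin
print(rotated.dtype)  # torch.float32 -> mismatches a bf16 cache buffer in index_copy_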

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng changed the title from "fix gptj data type missmatch" to "fix static cache data type missmatch" on Nov 20, 2024
@jiqing-feng changed the title from "fix static cache data type missmatch" to "fix static cache data type miss-match" on Nov 20, 2024
@jiqing-feng (Contributor, Author):

Hi @SunMarc. I left a comment explaining why the llama model doesn't have this issue. BTW, I also added low-precision static cache tests to prevent this kind of issue in the future; please review them. Thanks!

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@SunMarc (Member) left a comment


Thanks for the explanation! LGTM! Did you run the static cache tests you added to see if there are other models that require this fix?

@SunMarc SunMarc requested a review from ArthurZucker November 20, 2024 12:30
@ArthurZucker (Collaborator) left a comment


➕ on Marc's comment. The safe way to do this is to cast key and query to the cache's dtype, no? And do this in cache_utils rather than at the modeling level!

@jiqing-feng (Contributor, Author) commented Nov 21, 2024

➕ on Marc's comment. The safe way to do this is to cast key and query to the cache's dtype, no? And do this in cache_utils rather than at the modeling level!

Yes, I have applied your suggestions, thanks!

Thanks for the explanation! LGTM! Did you run the static cache tests you added to see if there are other models that require this fix?

The CI already runs the tests I changed, so currently no other models require this fix. Besides, I have moved the change into cache_utils, so it applies to all language models that use the static cache.
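
A minimal sketch of what a cast at the cache level can look like, assuming StaticCache-style preallocated key_cache/value_cache buffers; this is a simplified illustration, not the exact diff merged in this PR:

def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
    # Simplified StaticCache.update-style method (illustrative only).
    cache_position = cache_kwargs.get("cache_position")
    k_out = self.key_cache[layer_idx]
    v_out = self.value_cache[layer_idx]

    # Cast the incoming states to the buffers' dtype so index_copy_ sees
    # matching dtypes, even if the model produced float32 key/value states.
    key_states = key_states.to(k_out.dtype)
    value_states = value_states.to(v_out.dtype)

    k_out.index_copy_(2, cache_position, key_states)
    v_out.index_copy_(2, cache_position, value_states)
    return k_out, v_out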

(Outdated review thread on src/transformers/cache_utils.py)
@HuggingFaceDocBuilderDev:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@SunMarc SunMarc requested a review from ArthurZucker November 22, 2024 16:17
@jiqing-feng (Contributor, Author):

Hi @SunMarc, please review the new changes, thanks!

@SunMarc (Member) commented Nov 25, 2024

Hi @SunMarc, please review the new changes, thanks!

All good from my side. Pinging @ArthurZucker

@ArthurZucker (Collaborator) left a comment


Not completely sure we want to test for float32 as it's quite heavy

@SunMarc (Member) commented Nov 25, 2024

I think it was testing for float32 initially, and @jiqing-feng added coverage for float16, @ArthurZucker.

@ArthurZucker (Collaborator):

Sounds good then, merging!

@ArthurZucker ArthurZucker merged commit a464afb into huggingface:main Nov 25, 2024
@jiqing-feng jiqing-feng deleted the gptj branch November 26, 2024 01:06
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* fix gptj data type missmatch

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add low precision static cache tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix low-precision static cache tests

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* avoid config change

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change data type convert in cache copy

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* cast key value after k v out

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>