Fix tensors on "two devices" issue #32420 #33742

Closed

davedgd wants to merge 1 commit into huggingface:main from davedgd:patch-1

Conversation

@davedgd davedgd commented Sep 27, 2024

What does this PR do?

Fixes #32420 by placing both inv_freq_expanded and position_ids_expanded on the same device. This avoids the following error on this line:

freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)

Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)

Allows autoawq and other packages to correctly perform CPU offloading during quantization.

Note: I tested this using a Qwen2.5 model and successfully resolved the error. In principle, only the following change appeared necessary (inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)) on line 210 here:

inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)

Still, in searching through other models, I noticed differences depending on whether the # Core RoPE block code related to #29285 was present, so to be safe I applied the .to(x.device) change to both inv_freq_expanded and position_ids_expanded throughout all the models that use this code.
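
For illustration, here is a minimal, self-contained sketch of a RoPE block with the proposed .to(x.device) change applied (ToyRotaryEmbedding is a made-up name for this example; the real code lives in each model's modeling file):

import torch
from torch import nn

class ToyRotaryEmbedding(nn.Module):
    """Minimal stand-in for the transformers RoPE block, showing the proposed fix."""

    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Proposed change: pin both operands to the device of the input `x`,
        # so the matmul below never mixes CPU and CUDA tensors.
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float().to(x.device)
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()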

Fixes #32420

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr
@ArthurZucker

@JeevanBhoot

Yes, this fixed my issue with AutoGPTQ: AutoGPTQ/AutoGPTQ#729
I hope this gets merged!

Collaborator

@ArthurZucker ArthurZucker left a comment

cc @SunMarc shouldn't the position ids be automatically mapped to the device of that layer with accelerate?

lgtm otherwise!

@SunMarc
Member

SunMarc commented Oct 3, 2024

Hi @davedgd, thanks for the PR! Did you have this issue with prior versions of transformers? Could you share a reproducer of the issue? Still, I find it strange that the inputs of the forward are not on the same device. I think that's probably related to this PR.

@davedgd
Author

davedgd commented Oct 3, 2024

Hi @SunMarc -- thank you as well for the awesome package and for taking a look at this! I think you may be right about #32135 being the starting point for this issue (more on this below, but TL;DR: 4.43.0 is where the issue truly began, and that timing is consistent with when that PR was merged).

Here's a quick reproducible example adapted from the AutoAWQ examples (specifically from here):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'Qwen/Qwen2.5-32B-Instruct'
quant_path = 'model-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
    model_path, low_cpu_mem_usage=True, use_cache=False
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')

Using the current (as of writing) transformers 4.45.1, this will produce the following error, assuming CPU offloading comes into play (for a full-precision 32B model that's likely, but you may have more VRAM than me... 😄):

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)
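
For reference, the failure mode is simply a matmul between a CPU tensor and a CUDA tensor. A minimal standalone illustration (assuming a CUDA device is available; this is not the actual model code, just the same mismatch):

import torch

inv_freq_expanded = torch.randn(1, 64, 1)                                       # left on the CPU
position_ids_expanded = torch.arange(8, device="cuda")[None, None, :].float()   # placed on the GPU

# Raises: RuntimeError: Expected all tensors to be on the same device, ...
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)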

This issue has been noted by other users in both AutoAWQ (casper-hansen/AutoAWQ#558) and AutoGPTQ (AutoGPTQ/AutoGPTQ#729); some of these users also suggest the PR I've proposed here (i.e., #33742) resolves it for them.

Note that Qwen2.5 itself needs a relatively new version of transformers, but to help trace this down: reverting transformers to 4.42.4 makes the broader issue go away, as mentioned in this thread/post (although you would then need to switch to an older model like Llama 3 70B, since Llama 3.1, for example, requires a newer transformers version [i.e., 4.43.0] or the quantization will fail due to a separate RoPE scaling error).

Thanks again for looking into this!

PS. Feel free to use a different model besides Qwen/Qwen2.5-32B-Instruct -- as noted above, I just went with this one since it: 1) doesn't need an HF token to download; and 2) is big enough to require CPU offloading for quantization on my setup (48GB VRAM -- specifically an A6000 Ada).

@SunMarc
Member

SunMarc commented Oct 4, 2024

Thanks for sharing the reproducer, it helped me a lot! The issue we see is due to RoPE being computed in the model rather than in the decoder layers. For example, in the llama modeling code, we have:

        # create position embeddings to be shared across the decoder layers
        # was computed in the decoder layer before
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

Hence, because the AWQ code moves the modules manually, we need to move the rotary embedding layer to the appropriate device as well, just like it already does for the embedding layer.

In the init_quant function:

        self.awq_model.move_embed(self.model, best_device)
        # Note: added this line but each model is different. We should create a move_rope method
        self.awq_model.model.model.rotary_emb.to(best_device)

        # ...
        self.awq_model.move_embed(self.model, "cpu")
        # added line
        self.awq_model.model.model.rotary_emb.to("cpu")

I tested with this change and it runs fine.
So, I think the fix should live in AWQ and not in transformers. The same happens with AutoGPTQ because it is based on the same logic. Would you like to open the PR on the AWQ repository, @davedgd?
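
For the AWQ side, the suggested change could be wrapped in a small helper so each model class decides where its rotary embedding lives. A hypothetical sketch (move_rope and the attribute path are assumptions, not existing AutoAWQ API; the path model.rotary_emb inside the inner model matches llama-style architectures):

from torch import nn

def move_rope(model: nn.Module, device) -> None:
    # Hypothetical helper mirroring AutoAWQ's move_embed: move the shared
    # rotary embedding module (if the architecture has one) to `device`.
    inner = getattr(model, "model", model)
    rotary_emb = getattr(inner, "rotary_emb", None)
    if rotary_emb is not None:
        rotary_emb.to(device)

# Usage sketch inside init_quant, mirroring the existing move_embed calls:
# move_rope(self.awq_model.model, best_device)
# ... run calibration ...
# move_rope(self.awq_model.model, "cpu")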

@davedgd
Author

davedgd commented Oct 4, 2024

Hi @SunMarc -- thanks for doing this thorough investigation; I'm really grateful! Yes, I can absolutely submit this PR to AutoAWQ (and likely to AutoGPTQ as well, as you suggested). I'll take care of this shortly. I think we're good to close this PR in the meantime -- I'll reference it when I submit to the other packages.

@SunMarc
Member

SunMarc commented Oct 4, 2024

Thanks for taking care of that! Feel free to ask any questions!

@trevor-m

trevor-m commented Oct 8, 2024

@SunMarc I think this is an actual bug in transformers that should be fixed instead of requiring users to change their code. Could you take a look at my comment on the issue? Thanks! #32420 (comment)

@SunMarc
Member

SunMarc commented Oct 8, 2024

Thanks for investigating @trevor-m ! I've answered your comment!

Development

Successfully merging this pull request may close these issues.

Recent changes is causing "found at least two devices"

5 participants