
Shape mismatch when generating with multiple processes #32603

@ojh31

Description

System Info

  • transformers version: 4.42.4
  • Platform: Linux-5.15.0-106-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.29.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: yes (accelerate/FSDP)
  • Using GPU in script?: yes
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@gante @SunMarc @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run accelerate launch --config_file=accelerate_config.yaml foo.py

foo.py:

from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM, LlamaTokenizer

NAME = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(
    NAME, 
    padding_side="left", 
)
accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained(NAME)
model = accelerator.prepare(model)

# Two left-padded Llama-2 chat prompts, 99 tokens each (pad token id 2).
paired_input_ids = torch.tensor([
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
     1, 29961, 25580, 29962, 887, 526, 385, 7395, 6993, 1788,
     29889, 13866, 366, 674, 367, 5429, 278, 1788, 4800, 29889,
     450, 1404, 674, 769, 3896, 263, 4800, 29889, 960, 278,
     1404, 4800, 7087, 278, 1788, 4800, 29892, 736, 18016, 13566,
     3352, 29889, 960, 278, 1404, 4800, 947, 451, 1993, 278,
     1788, 4800, 29892, 736, 360, 1430, 29902, 3352, 29889, 18076,
     487, 3099, 1156, 278, 1404, 4800, 29889, 13, 13, 3924,
     4800, 29901, 15040, 13, 2659, 4800, 29901, 15040, 13, 5634,
     13, 13, 13, 22550, 29901, 518, 29914, 25580, 29962],
    [2, 2, 2, 2, 2, 2, 1, 29961, 25580, 29962,
     887, 526, 385, 7395, 6993, 1788, 29889, 13866, 366, 674,
     367, 5429, 278, 1788, 4800, 29889, 450, 1404, 674, 769,
     3896, 263, 4800, 29889, 960, 278, 1404, 4800, 7087, 278,
     1788, 4800, 29892, 736, 18016, 13566, 3352, 29889, 960, 278,
     1404, 4800, 947, 451, 1993, 278, 1788, 4800, 29892, 736,
     360, 1430, 29902, 3352, 29889, 18076, 487, 3099, 1156, 278,
     1404, 4800, 29889, 13, 13, 3924, 4800, 29901, 1757, 10582,
     284, 13, 2659, 4800, 29901, 1757, 10582, 284, 13, 5634,
     13, 13, 13, 22550, 29901, 518, 29914, 25580, 29962]
])
# Matching masks: 0 over the left padding, 1 over the prompt tokens.
paired_attention_mask = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1],
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
     1, 1, 1]
])

paired_dataset = TensorDataset(paired_input_ids, paired_attention_mask)

dataloader = DataLoader(
    dataset=paired_dataset,
    batch_size=1,  # Process one pair at a time
    shuffle=False,
)
# prepare() also shards the dataloader: with num_processes: 2, each rank
# receives exactly one of the two prompts.
dataloader = accelerator.prepare(dataloader)


for batch_input_ids, batch_attention_mask in dataloader:
    # A plain forward pass succeeds on both ranks.
    with torch.no_grad():
        model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
    # Gather the full parameters so generate() can run on the wrapped model;
    # synced_gpus=True keeps both ranks stepping until all of them finish.
    with FSDP.summon_full_params(model, recurse=False):
        outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            tokenizer=tokenizer,
            synced_gpus=True,
        )

accelerate_config.yaml:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: "no"
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: "no"
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Expected behavior

The script should generate text on both processes, but instead it throws:

Traceback (most recent call last):
  File "/root/robust-llm/pairs.py", line 66, in <module>
    outputs = model.generate(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 978, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 718, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 648, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (105) must match the existing size (104) at non-singleton dimension 3.  Target sizes: [1, 40, 1, 105].  Tensor sizes: [1, 1, 1, 104]
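
The failing operation is the broadcast of the 4-D causal mask against the attention inputs. The shapes in the message can be reproduced in isolation (illustrative only; the real tensors are built inside LlamaSdpaAttention.forward(), and 40 is the head count of Llama-2-13b):

import torch

# Illustrative repro of the shape error alone, not of the bug itself:
# a causal mask covering 104 positions of input_ids...
causal_mask = torch.zeros(1, 1, 1, 104)  # [batch, 1, q_len, mask_len]
# ...broadcast against attention inputs whose KV cache already holds 105.
causal_mask.expand(1, 40, 1, 105)
# RuntimeError: The expanded size of the tensor (105) must match the existing
# size (104) at non-singleton dimension 3.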

Hypothesis:
In transformers/generation/utils.py::GenerationMixin._sample(), inside the while self._has_unfinished_sequences() loop, we continue as soon as synced_gpus and this_peer_finished are both true. That continue skips the concatenation of next_tokens onto input_ids, but the forward pass has already run, so transformers/models/llama/modeling_llama.py::LlamaSdpaAttention.forward() keeps updating the past_key_value cache. Therefore, when one process finishes generation before the other, the finished process continues to expand its key-value cache but stops expanding its input tensors, so on the next step the shapes disagree (the 104 vs. 105 above). Maybe a simple fix would be to forcibly set past_key_value to None once this_peer_finished is set to True? A sketch of the loop follows.
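
To make the suspected control flow concrete, here is a heavily simplified sketch of the loop (not the actual implementation; sample_from is a hypothetical placeholder for the real logits-processing and sampling steps):

# Simplified sketch of GenerationMixin._sample() as of transformers 4.42.
# sample_from is a hypothetical placeholder, not a real transformers function.
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = self(**model_inputs)  # runs even for finished peers, and each
                                    # attention layer grows past_key_value in place
    if synced_gpus and this_peer_finished:
        continue                    # ...but input_ids is no longer extended, so the
                                    # cache drifts one position ahead per iteration
    next_tokens = sample_from(outputs.logits)
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)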
