
Torch.compile fail during inference with meta-llama/Meta-Llama-3.1-8B-Instruct #34604

@prasiyer

System Info

  • transformers version: 4.43.3
  • Platform: Linux-5.15.0-1074-azure-x86_64-with-glibc2.31
  • Python version: 3.11.9
  • Huggingface_hub version: 0.23.1
  • Safetensors version: 0.4.3
  • Accelerate version: 0.31.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: using device_map = "auto" in AutoModelForCausalLM.from_pretrained
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100 80GB PCIe
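When comparing the versions above with the file paths in the traceback, it can help to read the installed distribution versions without importing the packages themselves. Below is a minimal stdlib-only sketch. The package list is an assumption (the libraries named above plus jinja2, which appears in the traceback), and `collect_versions` is a hypothetical helper, not part of transformers or PyTorch.

```python
import importlib.metadata as metadata
import platform

# Assumed package list: libraries named in the System Info block,
# plus jinja2, which appears in the traceback.
DEFAULT_PACKAGES = (
    "torch",
    "transformers",
    "accelerate",
    "safetensors",
    "huggingface-hub",
    "jinja2",
)


def collect_versions(packages=DEFAULT_PACKAGES):
    """Return platform/Python info and installed package versions.

    Reads distribution metadata only, so none of the packages is imported.
    """
    info = {
        "platform": platform.platform(),
        "python": platform.python_version(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```

Printing this next to the traceback can make it easier to see whether the recorded version of each package matches the files that actually show up in the stack frames.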

Who can help?

@gante , @ArthurZucker
I get the following error when using torch.compile(). Sample code is included in the "Reproduction" section below.

Error:
Traceback (most recent call last):
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/route_utils.py", line 276, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/blocks.py", line 1923, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/blocks.py", line 1506, in call_function
    prediction = await fn(*processed_input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/utils.py", line 785, in async_wrapper
    response = await f(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/chat_interface.py", line 607, in _submit_fn
    response = await anyio.to_thread.run_sync(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2134, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 851, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vp899/projects/Agent_System/Code/Agent_Launch_UI_v2_Experiments.py", line 253, in contract_analyst_chat
    outputs = model.generate(input_ids, max_new_tokens=500, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 482, in transform
    tracer = InstructionTranslator(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2085, in __init__
    self._throw_if_in_functorch()
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2126, in _throw_if_in_functorch
    eager = torch._dynamo.lookup_backend("eager")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 58, in lookup_backend
    _lazy_import()
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 91, in _lazy_import
    import_submodule(backends)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1866, in import_submodule
    importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
  File "/anaconda/envs/pi2_py311/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/cudagraphs.py", line 10, in <module>
    from torch._inductor.cudagraph_trees import cudagraphify_impl
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 71, in <module>
    from torch._inductor.compile_fx import (
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 57, in <module>
    from .fx_passes.joint_graph import joint_graph_passes
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 12, in <module>
    from ..pattern_matcher import (
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py", line 46, in <module>
    from .lowering import fallback_node_due_to_unsupported_type
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 6002, in <module>
    import_submodule(kernel)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1866, in import_submodule
    importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
  File "/anaconda/envs/pi2_py311/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py", line 155, in <module>
    flex_attention_template = TritonTemplate(
                              ^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 453, in __init__
    self.template = self._template_from_string(source)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/codegen/common.py", line 1720, in _template_from_string
    return env.from_string(source)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 1108, in from_string
    return cls.from_code(self, self.compile(source), gs, None)
                               ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 768, in compile
    self.handle_exception(source=source_hint)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 939, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<unknown>", line 104, in template
torch._dynamo.exc.InternalTorchDynamoError: No filter named 'indent_except_first'.


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
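The last frames show Inductor building a Triton template from flex_attention.py by compiling a Jinja2 source string (`_template_from_string` → `env.from_string`). Jinja2 rejects that string because it uses a filter, `indent_except_first`, that is not registered on the environment. The same class of error can be reproduced with plain jinja2, outside PyTorch. In the minimal sketch below, the template text and the body of `indent_except_first` are hypothetical stand-ins, not PyTorch's actual template or helper. It assumes Jinja2 3.x, which reports unknown filters when the template is compiled.

```python
import jinja2


# Hypothetical stand-in for a template filter: indent every line of
# `source` except the first by `num_indents` levels of `spacing` spaces.
def indent_except_first(source: str, num_indents: int, spacing: int = 4) -> str:
    lines = source.splitlines(True)
    if len(lines) > 1:
        pad = " " * (spacing * num_indents)
        lines[1:] = [pad + line for line in lines[1:]]
    return "".join(lines)


# Hypothetical template that uses the filter.
TEMPLATE = "def kernel():\n    {{ body | indent_except_first(1) }}\n"


def render(register_filter: bool) -> str:
    env = jinja2.Environment(undefined=jinja2.StrictUndefined)
    if register_filter:
        env.filters["indent_except_first"] = indent_except_first
    # from_string compiles the template; if the filter is missing this raises
    # jinja2.TemplateAssertionError: No filter named 'indent_except_first'.
    template = env.from_string(TEMPLATE)
    return template.render(body="x = 1\nreturn x")


if __name__ == "__main__":
    print(render(register_filter=True))
    try:
        render(register_filter=False)
    except jinja2.TemplateAssertionError as exc:
        print(f"{type(exc).__name__}: {exc}")
```

Once the filter is registered on the environment, the same template compiles. This suggests that the template and the environment compiling it disagree about which filters exist.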

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# llama31_hf_token (Hugging Face access token) and conversation (chat messages) are defined elsewhere in the script (not shown)
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=llama31_hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", token=llama31_hf_token, attn_implementation="flash_attention_2")
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

...
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
outputs = model.generate(input_ids, max_new_tokens=500, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9)
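In the reproduction, `terminators` holds two token IDs: the tokenizer's EOS token and `<|eot_id|>`. It is passed as `eos_token_id`, which `generate` accepts as a list, so a sequence stops as soon as any of those IDs is produced, or after `max_new_tokens`. Below is a toy sketch of that stopping rule in plain Python. It ignores sampling (`do_sample`, `temperature`, `top_p`), batching and caching. `decode_until_terminator`, `next_token` and all token IDs are hypothetical.

```python
from typing import Callable, Iterable, List


def decode_until_terminator(
    next_token: Callable[[List[int]], int],  # hypothetical: returns the next token ID
    prompt: List[int],
    terminators: Iterable[int],
    max_new_tokens: int,
) -> List[int]:
    """Append tokens until one is in `terminators` or the budget runs out.

    Returns only the newly generated tokens; a terminator, if produced,
    is kept as the last element.
    """
    stop_ids = set(terminators)
    seq = list(prompt)
    for _ in range(max_new_tokens):
        tok = next_token(seq)
        seq.append(tok)
        if tok in stop_ids:
            break
    return seq[len(prompt):]


if __name__ == "__main__":
    scripted = iter([5, 7, 99, 3])  # made-up token stream
    print(decode_until_terminator(lambda seq: next(scripted), [1], [2, 99], 500))
```

Because membership is checked with a set, any number of terminator IDs costs the same per step.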

Expected behavior

The model should compile, and model.generate should return the generated answer.
