Skip to content

[Tracer] Fail to register _flash_attention_forward as leaf function #41

@zyeric

Description

@zyeric

Environment

torch                          2.6.0+cu124
flash_attn                     2.7.4.post1
transformers                   4.51.3

Reproduce

Add a file named transformers.py under nnscaler/graph/parser/external/ with the following content:

import copy
import logging
import string

from nnscaler.graph.function.dimops import ShapeAnno, OpAnno
from nnscaler.graph import parser
from nnscaler.ir import IRTensor

_logger = logging.getLogger(__name__)

try:

    def flash_attention_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str:
        """Build the nnscaler op annotation for ``_flash_attention_forward``.

        Supports grouped-query attention: when the query tensor carries more
        heads (axis 2) than the key/value tensors, the query head axis is
        annotated as ``(group_num group_size)`` so the grouping is preserved
        under partitioning.
        """
        num_q_heads = query_states.shape[2]
        num_kv_heads = key_states.shape[2]
        if num_q_heads == num_kv_heads:
            # Plain multi-head attention: one shared head annotation.
            q_anno = kv_anno = 'num_heads'
        else:
            assert num_q_heads % num_kv_heads == 0
            group_size = num_q_heads // num_kv_heads
            assert num_q_heads == value_states.shape[2] * group_size
            q_anno = f'(group_num {group_size})'
            kv_anno = 'group_num'
        # The mask only appears in the annotation when it is a traced tensor.
        if isinstance(attention_mask, IRTensor):
            return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^'
        return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^'

    from transformers.modeling_flash_attention_utils import _flash_attention_forward
    parser.register(flash_attention_anno)(_flash_attention_forward)
    _logger.info("FA registered successfully.")

except Exception as e:
    _logger.warning(f"FA will not be registered: {e}")

Add `from .transformers import *` to nnscaler/graph/parser/external/__init__.py.

test script

from typing import Dict, Any, Tuple, Optional, Callable, Union
import logging

import torch
from transformers import AutoConfig, AutoModelForCausalLM

import nnscaler
from nnscaler import parallelize, ComputeConfig
from nnscaler.policies import pas_dp
from nnscaler.graph.parser.register import CustomizedOps


def main():
    """Reproduce the issue: trace a flash-attention-2 Qwen model with nnscaler."""
    torch_dtype = torch.float32

    # Actor:
    # input:
    #   - input_ids / attention_mask / position_ids: torch.Tensor, shape [batch_size, seq_len]
    # output:
    #   - logits: torch.Tensor, shape [batch_size, seq_len, vocab_size]
    class ActorWrapper(torch.nn.Module):
        """Thin wrapper that returns only the logits of the causal LM."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask, position_ids, cache_position):
            # Flash attention needs half precision, so autocast when the
            # wrapped weights are fp32; otherwise run in the native dtype.
            if torch_dtype != torch.float32:
                import contextlib
                ctx = contextlib.nullcontext()
            else:
                ctx = torch.autocast(device_type='cuda', dtype=torch.bfloat16)
            with ctx:
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    cache_position=cache_position,
                    use_cache=False,
                )
                return outputs.logits

    model_name = 'Qwen/Qwen2.5-3B-Instruct'
    hf_config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=model_name,
        trust_remote_code=False,
        attn_implementation="flash_attention_2",
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name,
        torch_dtype=torch_dtype,
        config=hf_config,
    )

    wrapped = ActorWrapper(base_model)

    # The external module must have registered the FA forward as a custom op.
    assert 'transformers.modeling_flash_attention_utils._flash_attention_forward' in CustomizedOps.kOpMap
    print(CustomizedOps.kOpMap)

    batch_size, seq_len = 1, 4096
    dummy_input = {
        "input_ids": torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.int64),
        "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64),
        "position_ids": torch.arange(seq_len).expand(batch_size, -1).to(torch.int64),
        "cache_position": torch.arange(seq_len).to(torch.int64),
    }
    compute_config = ComputeConfig(
        plan_ngpus=1,
        runtime_ngpus=1,
        constant_folding=True,
        use_zero=1,
        inference_only=False,
    )
    parallelize(
        module_or_module_class=wrapped,
        dummy_forward_args=dummy_input,
        pas_policy=pas_dp,
        compute_config=compute_config,
        load_module=False,
    )


if __name__ == "__main__":
    # Raise log verbosity so nnscaler's registration/tracing messages show up.
    nnscaler.utils.set_default_logger_level(level=logging.INFO)
    main()

test command

python test.py

If you execute the command, you will see the following output, which shows that the tracer stepped into the registered function instead of treating it as a leaf:

2025-07-29 22:28:26 | WARNING | nnscaler.graph.parser.parser | Set python runtime function: flash_attn.flash_attn_interface.FlashAttnFunc.apply
2025-07-29 22:28:26 | WARNING | nnscaler.graph.parser.parser | non register python runtime function flash_attn.flash_attn_interface.FlashAttnFunc.apply has a non constant input: [FullTensor(id=172, shape=(1, 4096, 16, 128), req_grad=True), FullTensor(id=174, shape=(1, 4096, 2, 128), req_grad=True), FullTensor(id=176, shape=(1, 4096, 2, 128), req_grad=True), 0.0, 0.08838834764831845, True, (-1, -1), 0.0, None, False, False, True], You can register it as a customized function using nnscaler.register_op to remove this warning

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions