Environment
torch 2.6.0+cu124
flash_attn 2.7.4.post1
transformers 4.51.3
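(Not part of the original report: a quick convenience snippet to confirm the versions above in the active environment.)

```python
# Convenience check: print the versions of the packages listed above.
import torch
import flash_attn
import transformers

print("torch:", torch.__version__)
print("flash_attn:", flash_attn.__version__)
print("transformers:", transformers.__version__)
```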
Reproduce
Add a file named transformers.py under nnscaler/graph/parser/external/ with the following content:
```python
import copy
import logging
import string

from nnscaler.graph.function.dimops import ShapeAnno, OpAnno
from nnscaler.graph import parser
from nnscaler.ir import IRTensor

_logger = logging.getLogger(__name__)

try:
    def flash_attention_anno(query_states, key_states, value_states, attention_mask, *args, **kwargs) -> str:
        # Build the nnscaler dimension annotation for flash attention.
        # With grouped-query attention (fewer KV heads than query heads),
        # the query heads are grouped over the KV heads.
        if query_states.shape[2] != key_states.shape[2]:
            assert query_states.shape[2] % key_states.shape[2] == 0
            group_size = query_states.shape[2] // key_states.shape[2]
            assert query_states.shape[2] == value_states.shape[2] * group_size
            q_anno = f'(group_num {group_size})'
            kv_anno = 'group_num'
        else:
            q_anno = kv_anno = 'num_heads'

        if isinstance(attention_mask, IRTensor):
            return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^, b l^ -> b l^ {q_anno} vd^'
        else:
            return f'b l^ {q_anno} hd^, b s^ {kv_anno} hd^, b s^ {kv_anno} vd^ -> b l^ {q_anno} vd^'

    from transformers.modeling_flash_attention_utils import _flash_attention_forward
    parser.register(flash_attention_anno)(_flash_attention_forward)
    _logger.info("FA registered successfully.")
except Exception as e:
    _logger.warning(f"FA will not be registered: {e}")
```

Then add `from .transformers import *` to `nnscaler/graph/parser/external/__init__.py`.
Test script

```python
from typing import Dict, Any, Tuple, Optional, Callable, Union
import logging

import torch
from transformers import AutoConfig, AutoModelForCausalLM

import nnscaler
from nnscaler import parallelize, ComputeConfig
from nnscaler.policies import pas_dp
from nnscaler.graph.parser.register import CustomizedOps


def main():
    torch_dtype = torch.float32

    # Actor:
    # input:
    #   - input_ids: torch.Tensor, shape [batch_size, seq_len]
    #   - attention_mask: torch.Tensor, shape [batch_size, seq_len]
    #   - position_ids: torch.Tensor, shape [batch_size, seq_len]
    # output:
    #   - logits: torch.Tensor, shape [batch_size, seq_len, vocab_size]
    class ActorWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask, position_ids, cache_position):
            if torch_dtype == torch.float32:
                ctx = torch.autocast(device_type='cuda', dtype=torch.bfloat16)
            else:
                import contextlib
                ctx = contextlib.nullcontext()
            with ctx:
                output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, use_cache=False)
            return output.logits

    local_path = 'Qwen/Qwen2.5-3B-Instruct'
    hf_config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=local_path,
        trust_remote_code=False,
        attn_implementation="flash_attention_2",
    )
    actor_module = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=local_path,
        torch_dtype=torch_dtype,
        config=hf_config,
    )
    actor_wrapper = ActorWrapper(actor_module)

    # should register successfully
    assert 'transformers.modeling_flash_attention_utils._flash_attention_forward' in CustomizedOps.kOpMap
    print(CustomizedOps.kOpMap)

    bsz = 1
    seq_len = 4096
    dummy_input = {
        "input_ids": torch.randint(0, 1000, (bsz, seq_len), dtype=torch.int64),
        "attention_mask": torch.ones((bsz, seq_len), dtype=torch.int64),
        "position_ids": torch.arange(seq_len).expand(bsz, -1).to(torch.int64),
        "cache_position": torch.arange(seq_len).to(torch.int64),
    }

    compute_config = ComputeConfig(
        plan_ngpus=1,
        runtime_ngpus=1,
        constant_folding=True,
        use_zero=1,
        inference_only=False,
    )
    parallelize(
        module_or_module_class=actor_wrapper,
        dummy_forward_args=dummy_input,
        pas_policy=pas_dp,
        compute_config=compute_config,
        load_module=False,
    )


if __name__ == "__main__":
    nnscaler.utils.set_default_logger_level(level=logging.INFO)
    main()
```

Test command
```bash
python test.py
```

If you execute the command, you will see the following output, which means the tracer traced into the registered function instead of treating it as a single customized op:
```text
2025-07-29 22:28:26 | WARNING | nnscaler.graph.parser.parser | Set python runtime function: flash_attn.flash_attn_interface.FlashAttnFunc.apply
2025-07-29 22:28:26 | WARNING | nnscaler.graph.parser.parser | non register python runtime function flash_attn.flash_attn_interface.FlashAttnFunc.apply has a non constant input: [FullTensor(id=172, shape=(1, 4096, 16, 128), req_grad=True), FullTensor(id=174, shape=(1, 4096, 2, 128), req_grad=True), FullTensor(id=176, shape=(1, 4096, 2, 128), req_grad=True), 0.0, 0.08838834764831845, True, (-1, -1), 0.0, None, False, False, True], You can register it as a customized function using nnscaler.register_op to remove this warning
```
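For completeness, a minimal sketch of what the warning itself suggests, i.e. also registering the inner FlashAttnFunc.apply as a customized op. This is an assumption rather than a confirmed fix: it presumes nnscaler.register_op accepts the same annotation-callable pattern as parser.register above, and FlashAttnFunc.apply takes (q, k, v, dropout_p, ...), so the fourth argument reaching flash_attention_anno would be the dropout probability rather than an attention mask:

```python
# Hedged sketch (not from the issue): register the inner autograd function the
# tracer reached, as the warning message suggests.
# Assumes nnscaler.register_op(anno_fn)(fn) mirrors parser.register above.
import nnscaler
from flash_attn.flash_attn_interface import FlashAttnFunc
from nnscaler.graph.parser.external.transformers import flash_attention_anno

# FlashAttnFunc.apply receives (q, k, v, dropout_p, softmax_scale, ...), so
# flash_attention_anno falls into its mask-free branch here.
nnscaler.register_op(flash_attention_anno)(FlashAttnFunc.apply)
```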