Skip to content

add cudagraph annotation#2926

Draft
yushangdi wants to merge 2 commits intomainfrom
sy_cudagraph_annotation
Draft

add cudagraph annotation#2926
yushangdi wants to merge 2 commits intomainfrom
sy_cudagraph_annotation

Conversation

@yushangdi
Copy link
Copy Markdown

@yushangdi yushangdi commented Apr 10, 2026

Works with

Need pytorch/pytorch#179867

Result:

https://www.internalfb.com/intern/perfetto/open_trace/?manifold_path=perfetto_internal_traces%2Ftree%2Fshared_trace%2Fshangdiy_1d3dd617-60d4-4356-bb7e-46910c02da55_rank0_trace.annotated.json

"""Run graph_trainer for a few steps, capture annotations, and post-process trace."""
import json
import os
import pickle
import sys

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

import torch
torch.cuda.set_device(0)

sys.argv = [
    "train",
    "--module", "graph_trainer.llama3",
    "--config", "graph_trainer_llama3_debugmodel",
    "--compile.passes", "cudagraph",
    "--training.steps", "4",
    "--profiling.enable_profiling",
    "--profiling.save_traces_folder", "agent_space/traces",
    "--profiling.profile_freq", "4",
]

from torchtitan.train import main
main()

# Post-process
from torchtitan.experiments.graph_trainer.cudagraph import get_cudagraph_annotations
from torch.cuda._annotate_cuda_graph_trace import annotate_trace

annotations = get_cudagraph_annotations()
print(f"\n=== Captured {len(annotations)} annotated kernel(s) ===")
for tid, ann_list in list(annotations.items())[:20]:
    ann = ann_list[0]
    print(f"  toolsId=0x{tid:016x}  ->  {ann}")

import glob
traces = glob.glob("outputs/agent_space/traces/*/rank0_trace.json")
if not traces:
    traces = glob.glob("agent_space/traces/*/rank0_trace.json")
if traces:
    trace_path = max(traces, key=os.path.getmtime)
    print(f"\nPost-processing trace: {trace_path}")
    with open(trace_path) as f:
        trace = json.load(f)
    count = annotate_trace(trace, annotations)
    print(f"Annotated {count} kernel event(s)")
    print("\n--- Annotated kernel events (deduplicated) ---")
    seen = set()
    for e in trace["traceEvents"]:
        args = e.get("args", {})
        if args.get("graph node id", 0) == 0:
            continue
        comp = args.get("component", "")
        name = args.get("name", "")
        if not comp and not name:
            continue
        kname = e.get("name", "?")[:45]
        key = (kname, comp or name)
        if key in seen:
            continue
        seen.add(key)
        print(f"  {kname:45s}  component={comp:20s}  name={name}")
        if len(seen) >= 30:
            break
    out_path = trace_path.replace(".json", ".annotated.json")
    with open(out_path, "w") as f:
        json.dump(trace, f)
    print(f"\nSaved annotated trace to {out_path}")
else:
    print("No trace files found!")

Result:

image

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 10, 2026
@yushangdi yushangdi force-pushed the sy_cudagraph_annotation branch from e53d997 to 6ec8f51 Compare April 10, 2026 18:42
@yushangdi yushangdi force-pushed the sy_cudagraph_annotation branch from 6ec8f51 to 5ed6f0d Compare April 10, 2026 19:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant