Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions examples/profiling/profiling_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import gc
import inspect
import logging
import os
from dataclasses import dataclass, field
Expand Down Expand Up @@ -28,7 +29,18 @@ def annotate_pipeline(pipe):
"""Apply profiler annotations to key pipeline methods.

Monkey-patches bound methods so they appear as named spans in the trace.
Non-invasive — no source modifications required.
Non-invasive -- no source modifications required.

Sub-component methods are patched at the **class level** (via
`setattr(type(component), ...)`) rather than on the instance. This
ensures Python's descriptor protocol re-binds the wrapper to whichever
instance accesses it, so shallow-copied components (e.g. the duplicated
audio_scheduler inside LTX2) call their own logic rather than the
original instance's.

Returns:
A restore callable that undoes all patches and restores the
original method definitions, making the annotation transient.
"""
annotations = [
("transformer", "forward", "transformer_forward"),
Expand All @@ -37,6 +49,8 @@ def annotate_pipeline(pipe):
("scheduler", "step", "scheduler_step"),
]

saved = [] # (target, method_name, original_value, is_class_patch)

# Annotate sub-component methods
for component_name, method_name, label in annotations:
component = getattr(pipe, component_name, None)
Expand All @@ -45,12 +59,37 @@ def annotate_pipeline(pipe):
method = getattr(component, method_name, None)
if method is None:
continue
setattr(component, method_name, annotate(method, label))
if inspect.ismethod(method):
# Patch at the class level so the descriptor protocol correctly
# re-binds the wrapper to whichever instance accesses it. This
# prevents instance-isolation bugs when a component is
# shallow-copied. The original class attribute is saved so the
# patch can be reversed when restore() is called.
cls = type(component)
original = cls.__dict__.get(method_name)
setattr(cls, method_name, annotate(method.__func__, label))
saved.append((cls, method_name, original, True))
else:
original = component.__dict__.get(method_name)
setattr(component, method_name, annotate(method, label))
saved.append((component, method_name, original, False))

# Annotate pipeline-level methods
if hasattr(pipe, "encode_prompt"):
original = pipe.__dict__.get("encode_prompt")
pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")
saved.append((pipe, "encode_prompt", original, False))

def restore():
"""Undo all patches applied by annotate_pipeline, restoring originals."""
for target, name, original, is_class in saved:
if original is None:
if name in vars(target):
delattr(target, name)
else:
setattr(target, name, original)

return restore

def flush():
gc.collect()
Expand Down Expand Up @@ -130,7 +169,9 @@ def setup_pipeline(self, annotate=True):
pipe.set_progress_bar_config(disable=True)

if annotate:
annotate_pipeline(pipe)
self._restore_annotations = annotate_pipeline(pipe)
else:
self._restore_annotations = None
return pipe

def run(self):
Expand Down Expand Up @@ -176,7 +217,9 @@ def run(self):
)
)

# Cleanup
# Cleanup -- restore patched methods so class-level patches don't persist
if self._restore_annotations is not None:
self._restore_annotations()
pipe.to("cpu")
del pipe
flush()
Expand Down
Loading