diff --git a/examples/profiling/profiling_utils.py b/examples/profiling/profiling_utils.py index 1c7d59d42fde..5a2ffd6a7c8b 100644 --- a/examples/profiling/profiling_utils.py +++ b/examples/profiling/profiling_utils.py @@ -1,5 +1,6 @@ import functools import gc +import inspect import logging import os from dataclasses import dataclass, field @@ -28,7 +29,18 @@ def annotate_pipeline(pipe): """Apply profiler annotations to key pipeline methods. Monkey-patches bound methods so they appear as named spans in the trace. - Non-invasive — no source modifications required. + Non-invasive -- no source modifications required. + + Sub-component methods are patched at the **class level** (via + `setattr(type(component), ...)`) rather than on the instance. This + ensures Python's descriptor protocol re-binds the wrapper to whichever + instance accesses it, so shallow-copied components (e.g. the duplicated + audio_scheduler inside LTX2) call their own logic rather than the + original instance's. + + Returns: + A restore callable that undoes all patches and restores the + original method definitions, making the annotation transient. """ annotations = [ ("transformer", "forward", "transformer_forward"), @@ -37,6 +49,8 @@ def annotate_pipeline(pipe): ("scheduler", "step", "scheduler_step"), ] + saved = [] # (target, method_name, original_value, is_class_patch) + # Annotate sub-component methods for component_name, method_name, label in annotations: component = getattr(pipe, component_name, None) @@ -45,12 +59,37 @@ def annotate_pipeline(pipe): method = getattr(component, method_name, None) if method is None: continue - setattr(component, method_name, annotate(method, label)) + if inspect.ismethod(method): + # Patch at the class level so the descriptor protocol correctly + # re-binds the wrapper to whichever instance accesses it. This + # prevents instance-isolation bugs when a component is + # shallow-copied. The original class attribute is saved so the + # patch can be reversed when restore() is called. + cls = type(component) + original = cls.__dict__.get(method_name) + setattr(cls, method_name, annotate(method.__func__, label)) + saved.append((cls, method_name, original, True)) + else: + original = component.__dict__.get(method_name) + setattr(component, method_name, annotate(method, label)) + saved.append((component, method_name, original, False)) # Annotate pipeline-level methods if hasattr(pipe, "encode_prompt"): + original = pipe.__dict__.get("encode_prompt") pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt") + saved.append((pipe, "encode_prompt", original, False)) + + def restore(): + """Undo all patches applied by annotate_pipeline, restoring originals.""" + for target, name, original, is_class in saved: + if original is None: + if name in vars(target): + delattr(target, name) + else: + setattr(target, name, original) + return restore def flush(): gc.collect() @@ -130,7 +169,9 @@ def setup_pipeline(self, annotate=True): pipe.set_progress_bar_config(disable=True) if annotate: - annotate_pipeline(pipe) + self._restore_annotations = annotate_pipeline(pipe) + else: + self._restore_annotations = None return pipe def run(self): @@ -176,7 +217,9 @@ def run(self): ) ) - # Cleanup + # Cleanup -- restore patched methods so class-level patches don't persist + if self._restore_annotations is not None: + self._restore_annotations() pipe.to("cpu") del pipe flush()