diff --git a/backends/xnnpack/passes/__init__.py b/backends/xnnpack/passes/__init__.py index c4374c006a1..9cecf5ea482 100644 --- a/backends/xnnpack/passes/__init__.py +++ b/backends/xnnpack/passes/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional +from typing import List, Optional, Type from executorch.backends.xnnpack.passes.channels_last_tagged_reshape_pass import ( ChannelsLastTaggedReshapePass, @@ -29,7 +29,9 @@ class XNNPACKPassManager: def __init__( - self, exported_program: ExportedProgram, passes: Optional[List[PassType]] = None + self, + exported_program: ExportedProgram, + passes: Optional[List[Type[PassType]]] = None, ) -> None: """ A helper class to run multiple XNNPack passes on a program diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 680c508ae9a..a2febf572ac 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -5,9 +5,10 @@ # LICENSE file in the root directory of this source tree. 
import copy +import sys from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type import torch import torch._export as export @@ -25,6 +26,7 @@ from executorch.exir.backend.backend_api import to_backend, validation_disabled from executorch.exir.backend.partitioner import Partitioner from executorch.exir.passes.spec_prop_pass import SpecPropPass +from executorch.exir.print_program import pretty_print, print_program from executorch.extension.pybindings.portable_lib import ( # @manual _load_for_executorch_from_buffer, @@ -69,6 +71,39 @@ def graph_module(self): """ pass + def run_artifact(self, inputs): + """ + Returns the output of calling the artifact generated by this stage with inputs + """ + return self.artifact(*inputs) + + # Debug Tools for stages + def artifact_str(self): + """ + Return string printable artifact for this stage + """ + if isinstance(self.artifact, ExirExportedProgram): + return self.artifact.exported_program + return self.artifact + + def stage_banner(self): + """ + Returns banner string for this stage + """ + return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n" + + def dump_artifact(self, path_to_dump: Optional[str]): + """ + Dumps string printable artifact to path. 
If path_to_dump is not provided, it is printed to terminal + """ + if path_to_dump: + with open(path_to_dump, "a") as fp: + fp.write(str(self.stage_banner() + "\n")) + fp.write(str(self.artifact_str())) + else: + print(self.stage_banner() + "\n") + print(self.artifact_str()) + _stages_: Dict[str, Stage] = {} @@ -158,7 +193,7 @@ def graph_module(self) -> str: @register_stage class RunPasses(Stage): - def __init__(self, pass_list: Optional[List[PassType]] = None): + def __init__(self, pass_list: Optional[List[Type[PassType]]] = None): self.pass_list = pass_list self.edge_dialect_program = None @@ -207,31 +242,43 @@ def __init__( self.config = config or ExecutorchBackendConfig( passes=[SpecPropPass()], ) - self.exported_program = None + self.executorch_program = None def run(self, artifact: ExirExportedProgram, inputs=None): - self.exported_program = artifact.to_executorch(self.config) + self.executorch_program = artifact.to_executorch(self.config) @property def artifact(self) -> ExecutorchProgram: - return self.exported_program + return self.executorch_program @property def graph_module(self) -> str: - return self.exported_program.graph_module + return self.executorch_program.graph_module + + def dump_artifact(self, path_to_dump: Optional[str]): + """ + dump_artifact is overridden to dump the serialized program + """ + original_stdout = sys.stdout + + sys.stdout = open(path_to_dump, "a") if path_to_dump else sys.stdout + print(self.stage_banner() + "\n") + pretty_print(self.artifact.program) + print_program( + self.artifact.program, + show_meminfo=True, + mark_dynamic_shape_tensor=True, + ) + sys.stdout = original_stdout @register_stage class Serialize(Stage): - def __init__(self, filename: Optional[str] = None): + def __init__(self): self.buffer = None - self.filename = filename def run(self, artifact: ExecutorchProgram, inputs=None) -> None: self.buffer = artifact.buffer - if self.filename is not None: - with open(self.filename, "wb") as f: - f.write(self.buffer) @property def 
artifact(self) -> bytes: @@ -241,6 +288,24 @@ def artifact(self) -> bytes: def graph_module(self) -> None: return None + def run_artifact(self, inputs): + inputs_flattened, _ = tree_flatten(inputs) + executorch_module = _load_for_executorch_from_buffer(self.buffer) + executorch_output = copy.deepcopy( + executorch_module.run_method("forward", tuple(inputs_flattened)) + ) + return executorch_output + + def dump_artifact(self, path_to_dump: Optional[str]): + """ + dump_artifact is overridden to dump the serialized bytes into pte file + """ + if not path_to_dump: + raise RuntimeError("path_to_dump file not provided") + else: + with open(path_to_dump, "wb") as f: + f.write(self.artifact) + class Tester: def __init__( @@ -248,7 +313,7 @@ def __init__( module: torch.nn.Module, inputs: Tuple[torch.Tensor], ): - self.module = module + self.original_module = module self.inputs = inputs self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys())) self.pipeline = { @@ -276,8 +341,8 @@ def __init__( # Reference output from Eager mode self.reference_output = None - # Output by running a serialized/lowered module on ET - self.executorch_output = None + # Artifact output from stage + self.stage_output = None @staticmethod def _stage_name(stage) -> str: @@ -288,7 +353,7 @@ def _pre(self, stage): name: str = self._stage_name(stage) assert isinstance(name, str) and name in self.stages and not self.stages[name] - last_artifact = self.module + last_artifact = self.original_module if self.cur: assert self.cur in self.pipeline, f"Invalid state: {self.cur}" allowed_next_stages = self.pipeline[self.cur] @@ -332,6 +397,11 @@ def serialize(self, serialize_stage: Optional[Serialize] = None): return self._run_stage(serialize_stage or Serialize()) # Util functions + def dump_artifact(self, path: Optional[str] = None, stage: Optional[str] = None): + stage = stage or self.cur + self.stages[stage].dump_artifact(path) + return self + def get_artifact(self, stage: Optional[str] = 
None): stage = stage or self.cur return self.stages[stage].artifact @@ -354,18 +424,19 @@ def check_count(self, input: Dict[Any, int]): ) return self - def run_method(self, method="forward"): - # Reference - delegated_module = self.get_artifact(self._stage_name(Partition)) - self.reference_output = delegated_module(*self.inputs) - - # ExecuTorch - inputs_flattened, _ = tree_flatten(self.inputs) - serialized_buffer = self.get_artifact(self._stage_name(Serialize)) - executorch_module = _load_for_executorch_from_buffer(serialized_buffer) - self.executorch_output = copy.deepcopy( - executorch_module.run_method(method, tuple(inputs_flattened)) + def run_method( + self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None + ): + inputs_to_run = inputs or self.inputs + # Reference Output + self.reference_output = self.stages[self._stage_name(Export)].run_artifact( + inputs_to_run ) + + # Output from running artifact at stage + stage = stage or self.cur + self.stage_output = self.stages[stage].run_artifact(inputs_to_run) + return self @staticmethod @@ -377,25 +448,31 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): relative tolerance is 1e-3. 
""" - # Compare the result from executor and eager mode direclty - if isinstance(ref_output, tuple) or isinstance(ref_output, list): - # Multiple outputs executor always returns tuple, even if there is one output - assert len(ref_output) == len(model_output) - for i in range(len(ref_output)): - assert torch.allclose( - model_output[i], - ref_output[i], - atol=atol, - rtol=rtol, - ) - else: - # If one output, eager returns tensor while executor returns a tuple(tensor) of size 1 - assert torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) + # Multiple outputs executor always returns tuple, even if there is one output + assert len(ref_output) == len(model_output) + for i in range(len(ref_output)): + assert torch.allclose( + model_output[i], + ref_output[i], + atol=atol, + rtol=rtol, + ) def compare_outputs(self, atol=1e-03, rtol=1e-03): + """ + Compares the original of the original nn module with the output of the generated artifact. + This requres calling run_method before calling compare_outputs. As that runs the generated + artifact on the sample inputs and sets the stage output to be compared against the reference + """ assert self.reference_output is not None - assert self.executorch_output is not None + assert self.stage_output is not None + + # Wrap both outputs as tuple, since executor output is always a tuple even if single tensor + if isinstance(self.reference_output, torch.Tensor): + self.reference_output = (self.reference_output,) + if isinstance(self.stage_output, torch.Tensor): + self.stage_output = (self.stage_output,) self._assert_outputs_equal( - self.executorch_output, self.reference_output, atol=atol, rtol=rtol + self.stage_output, self.reference_output, atol=atol, rtol=rtol ) return self