Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/xnnpack/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional
from typing import List, Optional, Type

from executorch.backends.xnnpack.passes.channels_last_tagged_reshape_pass import (
ChannelsLastTaggedReshapePass,
Expand All @@ -29,7 +29,9 @@

class XNNPACKPassManager:
def __init__(
self, exported_program: ExportedProgram, passes: Optional[List[PassType]] = None
self,
exported_program: ExportedProgram,
passes: Optional[List[Type[PassType]]] = None,
) -> None:
"""
A helper class to run multiple XNNPack passes on a program
Expand Down
161 changes: 119 additions & 42 deletions backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# LICENSE file in the root directory of this source tree.

import copy
import sys
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch._export as export
Expand All @@ -25,6 +26,7 @@
from executorch.exir.backend.backend_api import to_backend, validation_disabled
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.print_program import pretty_print, print_program

from executorch.extension.pybindings.portable_lib import ( # @manual
_load_for_executorch_from_buffer,
Expand Down Expand Up @@ -69,6 +71,39 @@ def graph_module(self):
"""
pass

def run_artifact(self, inputs):
"""
Returns the output of calling the artifact generated by this stage with inputs
"""
return self.artifact(*inputs)

# Debug Tools for stages
def artifact_str(self):
"""
Return string printable artifact for this stage
"""
if isinstance(self.artifact, ExirExportedProgram):
return self.artifact.exported_program
return self.artifact

def stage_banner(self):
"""
Returns banner string for this stage
"""
return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n"

def dump_artifact(self, path_to_dump: Optional[str]):
"""
Dumps string printable artifact to path. If path_to_dump, then it is printed to terminal
"""
if path_to_dump:
with open(path_to_dump, "a") as fp:
fp.write(str(self.stage_banner() + "\n"))
fp.write(str(self.artifact_str()))
else:
print(self.stage_banner() + "\n")
print(self.artifact_str())


_stages_: Dict[str, Stage] = {}

Expand Down Expand Up @@ -158,7 +193,7 @@ def graph_module(self) -> str:

@register_stage
class RunPasses(Stage):
def __init__(self, pass_list: Optional[List[PassType]] = None):
def __init__(self, pass_list: Optional[List[Type[PassType]]] = None):
self.pass_list = pass_list
self.edge_dialect_program = None

Expand Down Expand Up @@ -207,31 +242,43 @@ def __init__(
self.config = config or ExecutorchBackendConfig(
passes=[SpecPropPass()],
)
self.exported_program = None
self.executorch_program = None

def run(self, artifact: ExirExportedProgram, inputs=None):
self.exported_program = artifact.to_executorch(self.config)
self.executorch_program = artifact.to_executorch(self.config)

@property
def artifact(self) -> ExecutorchProgram:
return self.exported_program
return self.executorch_program

@property
def graph_module(self) -> str:
return self.exported_program.graph_module
return self.executorch_program.graph_module

def dump_artifact(self, path_to_dump: Optional[str]):
"""
dump_artifact is overriden to dump the serialized program
"""
original_stdout = sys.stdout

sys.stdout = open(path_to_dump, "a") if path_to_dump else sys.stdout
print(self.stage_banner() + "\n")
pretty_print(self.artifact.program)
print_program(
self.artifact.program,
show_meminfo=True,
mark_dynamic_shape_tensor=True,
)
sys.stdout = original_stdout


@register_stage
class Serialize(Stage):
def __init__(self, filename: Optional[str] = None):
def __init__(self):
self.buffer = None
self.filename = filename

def run(self, artifact: ExecutorchProgram, inputs=None) -> None:
self.buffer = artifact.buffer
if self.filename is not None:
with open(self.filename, "wb") as f:
f.write(self.buffer)

@property
def artifact(self) -> bytes:
Expand All @@ -241,14 +288,32 @@ def artifact(self) -> bytes:
def graph_module(self) -> None:
return None

def run_artifact(self, inputs):
inputs_flattened, _ = tree_flatten(inputs)
executorch_module = _load_for_executorch_from_buffer(self.buffer)
executorch_output = copy.deepcopy(
executorch_module.run_method("forward", tuple(inputs_flattened))
)
return executorch_output

def dump_artifact(self, path_to_dump: Optional[str]):
"""
dump_artifact is overridden to dump the serialized bytes into pte file
"""
if not path_to_dump:
raise RuntimeError("path_to_dump file not provided")
else:
with open(path_to_dump, "wb") as f:
f.write(self.artifact)


class Tester:
def __init__(
self,
module: torch.nn.Module,
inputs: Tuple[torch.Tensor],
):
self.module = module
self.original_module = module
self.inputs = inputs
self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys()))
self.pipeline = {
Expand Down Expand Up @@ -276,8 +341,8 @@ def __init__(
# Reference output from Eager mode
self.reference_output = None

# Output by running a serialized/lowered module on ET
self.executorch_output = None
# Artifact output from stage
self.stage_output = None

@staticmethod
def _stage_name(stage) -> str:
Expand All @@ -288,7 +353,7 @@ def _pre(self, stage):
name: str = self._stage_name(stage)
assert isinstance(name, str) and name in self.stages and not self.stages[name]

last_artifact = self.module
last_artifact = self.original_module
if self.cur:
assert self.cur in self.pipeline, f"Invalid state: {self.cur}"
allowed_next_stages = self.pipeline[self.cur]
Expand Down Expand Up @@ -332,6 +397,11 @@ def serialize(self, serialize_stage: Optional[Serialize] = None):
return self._run_stage(serialize_stage or Serialize())

# Util functions
def dump_artifact(self, path: Optional[str] = None, stage: Optional[str] = None):
stage = stage or self.cur
self.stages[stage].dump_artifact(path)
return self

def get_artifact(self, stage: Optional[str] = None):
stage = stage or self.cur
return self.stages[stage].artifact
Expand All @@ -354,18 +424,19 @@ def check_count(self, input: Dict[Any, int]):
)
return self

def run_method(self, method="forward"):
# Reference
delegated_module = self.get_artifact(self._stage_name(Partition))
self.reference_output = delegated_module(*self.inputs)

# ExecuTorch
inputs_flattened, _ = tree_flatten(self.inputs)
serialized_buffer = self.get_artifact(self._stage_name(Serialize))
executorch_module = _load_for_executorch_from_buffer(serialized_buffer)
self.executorch_output = copy.deepcopy(
executorch_module.run_method(method, tuple(inputs_flattened))
def run_method(
self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None
):
inputs_to_run = inputs or self.inputs
# Reference Output
self.reference_output = self.stages[self._stage_name(Export)].run_artifact(
inputs_to_run
)

# Output from running artifact at stage
stage = stage or self.cur
self.stage_output = self.stages[stage].run_artifact(inputs_to_run)

return self

@staticmethod
Expand All @@ -377,25 +448,31 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
relative tolerance is 1e-3.
"""

# Compare the result from executor and eager mode direclty
if isinstance(ref_output, tuple) or isinstance(ref_output, list):
# Multiple outputs executor always returns tuple, even if there is one output
assert len(ref_output) == len(model_output)
for i in range(len(ref_output)):
assert torch.allclose(
model_output[i],
ref_output[i],
atol=atol,
rtol=rtol,
)
else:
# If one output, eager returns tensor while executor returns a tuple(tensor) of size 1
assert torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
# Multiple outputs executor always returns tuple, even if there is one output
assert len(ref_output) == len(model_output)
for i in range(len(ref_output)):
assert torch.allclose(
model_output[i],
ref_output[i],
atol=atol,
rtol=rtol,
)

def compare_outputs(self, atol=1e-03, rtol=1e-03):
"""
Compares the original of the original nn module with the output of the generated artifact.
This requres calling run_method before calling compare_outputs. As that runs the generated
artifact on the sample inputs and sets the stage output to be compared against the reference
"""
assert self.reference_output is not None
assert self.executorch_output is not None
assert self.stage_output is not None

# Wrap both outputs as tuple, since executor output is always a tuple even if single tensor
if isinstance(self.reference_output, torch.Tensor):
self.reference_output = (self.reference_output,)
if isinstance(self.stage_output, torch.Tensor):
self.stage_output = (self.stage_output,)
self._assert_outputs_equal(
self.executorch_output, self.reference_output, atol=atol, rtol=rtol
self.stage_output, self.reference_output, atol=atol, rtol=rtol
)
return self