diff --git a/sdk/TARGETS b/sdk/TARGETS index 92827b151b0..10fc8859c14 100644 --- a/sdk/TARGETS +++ b/sdk/TARGETS @@ -1,28 +1,11 @@ -load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") load("@fbcode_macros//build_defs:python_library.bzl", "python_library") oncall("executorch") python_library( name = "lib", - srcs = [ - "__init__.py", - "lib.py", - ], - deps = [ - "//executorch/sdk/edir:et_schema", - "//executorch/sdk/etdb:etdb", - "//executorch/sdk/etrecord:etrecord", - ], -) - -python_binary( - name = "cli", - main_src = "lib.py", - par_style = "xar", + srcs = ["__init__.py"], deps = [ - "//executorch/sdk/edir:et_schema", - "//executorch/sdk/etdb:etdb", "//executorch/sdk/etrecord:etrecord", ], ) diff --git a/sdk/__init__.py b/sdk/__init__.py index ff41e832dbd..69786e88286 100644 --- a/sdk/__init__.py +++ b/sdk/__init__.py @@ -10,12 +10,8 @@ parse_etrecord, ) -from executorch.sdk.lib import debug_etrecord, debug_etrecord_path - __all__ = [ "ETRecord", "generate_etrecord", "parse_etrecord", - "debug_etrecord", - "debug_etrecord_path", ] diff --git a/sdk/etdb/TARGETS b/sdk/etdb/TARGETS index 45142992234..53684cb6876 100644 --- a/sdk/etdb/TARGETS +++ b/sdk/etdb/TARGETS @@ -2,30 +2,6 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library") oncall("executorch") -python_library( - name = "row_schema", - srcs = [ - "row_schema.py", - ], - deps = [ - "//executorch/sdk/edir:base_schema", - "//executorch/sdk/edir:et_schema", - ], -) - -python_library( - name = "etdb", - srcs = [ - "etdb.py", - ], - deps = [ - "fbsource//third-party/pypi/tabulate:tabulate", - ":row_schema", - "//executorch/sdk/edir:base_schema", - "//executorch/sdk/edir:et_schema", - ], -) - python_library( name = "inspector", srcs = [ diff --git a/sdk/etdb/etdb.py b/sdk/etdb/etdb.py deleted file mode 100644 index c7b460e5d42..00000000000 --- a/sdk/etdb/etdb.py +++ /dev/null @@ -1,596 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import ( - Any, - Callable, - Dict, - List, - Mapping, - Optional, - Sequence, - Set, - Tuple, - Union, -) - -from executorch.sdk.edir.base_schema import Node, OperatorGraph, OperatorNode, ValueNode -from executorch.sdk.edir.et_schema import PROFILE_STAT_HEADER, RESERVED_METADATA_ARG -from executorch.sdk.etdb.row_schema import ( - AbstractNodeInstanceRow, - GraphInstanceRow, - OpInstanceRow, - OpSummaryRow, - ValueInstanceRow, -) -from tabulate import tabulate - - -# Generate OpSummary Rows from Grouping by module type -def _gen_module_summaries(modules: List[GraphInstanceRow]) -> Dict[str, OpSummaryRow]: - # Group by module type - summaries = {} - for module in modules: - if (module_type := module.get_module_type()) not in summaries: - summaries[module_type] = OpSummaryRow(module_type, []) - summaries[module_type].elements.append(module) - - return summaries - - -# Generate OpSummary Rows from Grouping by operator type -# Extract stats from an aggregated ops summary table if provided -def _gen_op_summaries( - ops: List[OpInstanceRow], aggr_op_stats: Optional[List[Tuple[Any, ...]]] = None -) -> Dict[str, OpSummaryRow]: - # Group by operator type - grouped_ops = {} - for op in ops: - if op.op not in grouped_ops: - grouped_ops[op.op] = [] - grouped_ops[op.op].append(op) - - summaries = {} - - # Extract stats from the pregenerated table - # Note: These fields can be opaquely extracted from the pregenerated table - if aggr_op_stats is not None: - header = aggr_op_stats[0] - for row in aggr_op_stats[1:]: - op_name = row[header.index(PROFILE_STAT_HEADER.NAME.value)] - mean_ms = row[header.index(PROFILE_STAT_HEADER.MEAN_MS.value)] - min_ms = row[header.index(PROFILE_STAT_HEADER.MIN_MS.value)] - p10_ms = row[header.index(PROFILE_STAT_HEADER.P10_MS.value)] - p90_ms = row[header.index(PROFILE_STAT_HEADER.P90_MS.value)] - max_ms = row[header.index(PROFILE_STAT_HEADER.MAX_MS.value)] - summaries[op_name] = OpSummaryRow( - op_name, grouped_ops[op_name], mean_ms, min_ms, p10_ms, p90_ms, max_ms - ) - else: - for name, grouped_op in grouped_ops.items(): - summaries[name] = OpSummaryRow(name, grouped_op) - - return summaries - - -# Look for specific Pregenerated Tables, returning None if not found -def _extract_pre_generated_tables(metadata: Dict[str, Any]): - return ( - metadata.get("tables", {}).get(RESERVED_METADATA_ARG.AGGREGATED_OP_TABLE.value), - metadata.get("tables", {}).get(RESERVED_METADATA_ARG.RUN_SUMMARY_TABLE.value), - ) - - -# Given a list of InstanceRows, populate their output_nodes fields based on the input nodes -# of the other InstanceRows -def _populate_outputs( - instances_with_inputs: List[Union[OpInstanceRow, ValueInstanceRow]], - operator_instances: Dict[str, OpInstanceRow], - constant_instances: Dict[str, ValueInstanceRow], - input_instances: Dict[str, ValueInstanceRow], -): - for row in instances_with_inputs: - for input_node in row.input_nodes: - if input_node in operator_instances: - operator_instances[input_node].output_nodes.append(row.name) - elif input_node in constant_instances: - constant_instances[input_node].output_nodes.append(row.name) - elif input_node in input_instances: - input_instances[input_node].output_nodes.append(row.name) - - -# Print out all Initial ET graph tables -def _print_all( - input_instances: Dict[str, ValueInstanceRow], - operator_instances: Dict[str, Union[OpInstanceRow, GraphInstanceRow]], - operator_summary: Dict[str, OpSummaryRow], - output_instances: Dict[str, ValueInstanceRow], - verbose: bool = False, -): - print("Inputs") - print_rows(list(input_instances.values()), verbose) - - print("Operators") - print_rows(list(operator_instances.values()), verbose) - - print("Aggregated Operator/Module Summaries") - print_rows(list(operator_summary.values()), verbose) - - print("Outputs") - print_rows(list(output_instances.values()), verbose) - - -# Pyre is being weird with subclass typing: -# rows is List[AbstractInstanceRow] such that each row is the same subtype -# -# Give a list of AbstractInstanceRows, print in a table format -def print_rows(rows: List[Any], verbose: bool = False): - if len(rows) > 0: - print_table( - type(rows[0]).get_schema_header(verbose), - [entry.to_row_format(verbose) for entry in rows], - ) - - -# Table format used in ETDB -def print_table(header: List[str], rows: List[Sequence[Any]]): - empty_columns = [False] * len(header) - # Drop Columns with all None - for index in range(len(header)): - if all( - (index >= len(row) or row[index] is None or row[index] == "") - for row in rows - ): - empty_columns[index] = True - - header = [val for index, val in enumerate(header) if not empty_columns[index]] - rows = [ - [val for index, val in enumerate(row) if not empty_columns[index]] - for row in rows - ] - - print(tabulate(rows, headers=header, tablefmt="fancy_grid")) - - -# Given a list of row identifing strings, print out the rows -# within corresponding tables -def _print_related_rows( - inputs: List[str], - input_instances: Dict[str, ValueInstanceRow], - operator_instances: Dict[str, OpInstanceRow], - constant_instances: Dict[str, ValueInstanceRow], - operator_summary: Dict[str, OpSummaryRow], - output_instances: Dict[str, ValueInstanceRow], - subgraph_instances: Dict[str, GraphInstanceRow], - verbose: bool = False, -): - ops = [] - consts = [] - model_vals = [] - modules = [] - for in_arg in inputs: - if in_arg in operator_instances: - ops.append(operator_instances[in_arg]) - elif in_arg in constant_instances: - consts.append(constant_instances[in_arg]) - elif in_arg in input_instances: - model_vals.append(input_instances[in_arg]) - elif in_arg in output_instances: - model_vals.append(output_instances[in_arg]) - elif in_arg in subgraph_instances: - modules.append(subgraph_instances[in_arg]) - - tables = { - "Modules": modules, - "Operators": ops, - "Constants": consts, - "Model Values": model_vals, - } - for name, rows in tables.items(): - if len(rows) > 0: - print(name) - print_rows(rows, verbose) - - -# Evaluate the request if it is asking for a backstep in history -def _eval_backstep( - target: str, - history: List[str], - input_instances: Dict[str, ValueInstanceRow], - operator_instances: Dict[str, OpInstanceRow], - constant_instances: Dict[str, ValueInstanceRow], - operator_summary: Dict[str, OpSummaryRow], - output_instances: Dict[str, ValueInstanceRow], - subgraph_instances: Dict[str, GraphInstanceRow], - verbose: bool = False, -) -> bool: - back_representations = {"Back", "back", "b"} - if target not in back_representations: - return False - - if len(history) <= 1: - print("No history found") - else: - history.pop() - _eval( - history.pop(), - history, - input_instances, - operator_instances, - constant_instances, - operator_summary, - output_instances, - subgraph_instances, - verbose, - ) - - return True - - -# Evaluate the request if it is asking for an operator or module summary -def _eval_op_summary( - target: str, - history: List[str], - operator_summary: Dict[str, OpSummaryRow], - verbose: bool = False, -) -> bool: - if ( - target not in operator_summary - and (target := "[Sub Module] " + target) not in operator_summary - ): - return False - - selection = operator_summary[target] - - print("\nSelection\n---------") - print_rows([selection], verbose) - - print("\nElements\n---------") - if len(selection.elements) > 0: - print("Forward") - print_rows(selection.elements, verbose) - - history.append(target) - - return True - - -# Parse and perform a single pass of debug printing given a target input -# Request Types: -# - Backstep in History -# - Model Value (Input, Output) -# - Op Instance -# - Op Summary -# - Constant Instance -def _eval( # noqa C901 - target: str, - history: List[str], - input_instances: Dict[str, ValueInstanceRow], - operator_instances: Dict[str, OpInstanceRow], - constant_instances: Dict[str, ValueInstanceRow], - operator_summary: Dict[str, OpSummaryRow], - output_instances: Dict[str, ValueInstanceRow], - subgraph_instances: Dict[str, GraphInstanceRow], - verbose: bool = False, -): - # Evaluate if request is to backstep in history - if _eval_backstep( - target, - history, - input_instances, - operator_instances, - constant_instances, - operator_summary, - output_instances, - subgraph_instances, - verbose, - ): - return - - # Evaluate if request is for an op summary - if _eval_op_summary(target, history, operator_summary, verbose): - return - - # Check for valid inputs - selection = None - sources = [ - input_instances, - constant_instances, - output_instances, - operator_instances, - subgraph_instances, - ] - for source in sources: - if target in source: - selection = source[target] - - # Valid input not found - if selection is None: - print("Invalid Input") - return - - print("\nSelection\n---------") - print_rows([selection], verbose) - - def _curried_print_tables(inputs: List[str]): - _print_related_rows( - inputs, - input_instances, - operator_instances, - constant_instances, - operator_summary, - output_instances, - subgraph_instances, - verbose, - ) - - if target in subgraph_instances: - print("\nElements\n------") - # pyre-ignore isinstance doesn't play well with imported dataclasses - _curried_print_tables(selection.elements) - - input_nodes = selection.input_nodes - if len(input_nodes) > 0: - print("\nInputs\n------") - _curried_print_tables(input_nodes) - - output_nodes = selection.output_nodes - if len(output_nodes) > 0: - print("\nOutputs\n-------") - _curried_print_tables(output_nodes) - - module = selection.parent_graph - if module is not None: - print("\nParent Module\n-------------") - print_rows([subgraph_instances[module]], verbose) - - # Op Summary - if isinstance(selection, OpInstanceRow): - print("\nOperator Summary\n----------------") - print_rows([operator_summary[selection.op]], verbose) - - if isinstance(selection, GraphInstanceRow): - print("\nModule Summary\n--------------") - print_rows([operator_summary[selection.get_module_type()]], verbose) - - history.append(target) - - -# Loop containing the interactive debugging state and processing -def enter_interactive_debugging( - input_instances: Dict[str, ValueInstanceRow], - operator_instances: Dict[str, OpInstanceRow], - constant_instances: Dict[str, ValueInstanceRow], - operator_summary: Dict[str, OpSummaryRow], - output_instances: Dict[str, ValueInstanceRow], - subgraph_instances: Dict[str, GraphInstanceRow], - verbose: bool = False, -): - history = [] - - while True: - print( - "\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" - ) - target = input( - "Select one of the following (Node/Module Type, Node/Module Instance, Constant Name):\n> " - ) - _eval( - target, - history, - input_instances, - operator_instances, - constant_instances, - operator_summary, - output_instances, - subgraph_instances, - verbose, - ) - - -# Select one of the graphs for debugging -# TODO: Add the ability to toggle between graphs -def debug_graphs(graphs: Mapping[str, OperatorGraph], verbose: bool = False): - print("Graphs: ", "\t".join(graphs.keys())) - target = input("Select a graph to investigate:\n> ") - if target not in graphs: - target = input("Invalid Selection, Please try again:\n> ") - - debug_graph(graphs[target], verbose) - - -# Entry point for interactive debugging via ETDB -# Complexity lint, will be fixed in refactor -def debug_graph(graph: OperatorGraph, verbose: bool = False): # noqa C901 - # Visual Separator - print("\n") - - metadata = graph.metadata - aggregated_op_table, run_summary_table = None, None - if metadata is not None: - # Extract Pregenerated Tables - aggregated_op_table, run_summary_table = _extract_pre_generated_tables(metadata) - - # High level tables - top_graph_instances = [ - GraphInstanceRow.gen_from_operator_node(element) - for element in graph.elements - if isinstance(element, OperatorGraph) - ] - - if len(top_graph_instances) != len(graph.elements): - raise RuntimeError( - "Mixing Nodes and Graphs within OperatorGraph currently unsupported" - ) - - if run_summary_table is not None: - print("Run Summary Table") - print_table(run_summary_table[0], run_summary_table[1:]) - - # Print High Level Tables - print_table( - GraphInstanceRow.get_schema_header(verbose, count_format=True), - [ - entry.to_row_format(verbose, count_format=True) - for entry in top_graph_instances - ], - ) - - # Construct rows - input_instances: Dict[str, ValueInstanceRow] = {} - output_instances: Dict[str, ValueInstanceRow] = {} - operator_instances: Dict[str, OpInstanceRow] = {} - constant_instances: Dict[str, ValueInstanceRow] = {} - subgraph_instances: Dict[str, GraphInstanceRow] = {} - - # Given a string identifier, return the corresponding NodeInstanceRow - def find_instance_row(name: str) -> AbstractNodeInstanceRow: - sources = [ - input_instances, - output_instances, - operator_instances, - constant_instances, - subgraph_instances, - ] - for source in sources: - if name in source: - return source[name] - - raise Exception(f"Could not find row identified with {name}") - - # Convert the provided Node/OperatorGraph into an InstanceRow and update the - # corresponding data structures - def add_row_instance( - element: Union[Node, OperatorGraph], parent: Optional[str] = None - ) -> None: - if isinstance(element, ValueNode): - row = ValueInstanceRow.gen_from_operator_node(element, parent) - if sub_graph.graph_name == "inputs": - input_instances[element.name] = row - elif sub_graph.graph_name == "outputs": - output_instances[element.name] = row - else: - constant_instances[element.name] = row - elif isinstance(element, OperatorNode): - row = operator_instances[ - element.name - ] = OpInstanceRow.gen_from_operator_node(element, parent) - elif isinstance(element, OperatorGraph): - row = subgraph_instances[ - element.graph_name - ] = GraphInstanceRow.gen_from_operator_node(element, parent) - for child in element.elements: - add_row_instance(child, parent=element.graph_name) - else: - raise RuntimeError( - "Mixing Nodes and Graphs within OperatorGraph currently unsupported" - ) - - for sub_graph in graph.elements: - # Enforced above - assert isinstance(sub_graph, OperatorGraph) - for element in sub_graph.elements: - add_row_instance(element) - - # Populate Output - instances_with_inputs = list(output_instances.values()) + list( - operator_instances.values() - ) - _populate_outputs( - instances_with_inputs, operator_instances, constant_instances, input_instances - ) - - def collect_recursively( - instance_node: AbstractNodeInstanceRow, - graph_fn: Callable[[GraphInstanceRow], List[str]], - general_fn: Callable[[AbstractNodeInstanceRow], List[str]], - ) -> Set[str]: - """ - Recursively collect the results of applying graph_fn/general_fn on - a NodeInstanceRow - - - If instance_node is not a subgraph: return the result of general_fn(node) - - If instance_node is a subgraph: return the result of graph_fn(node) plus the - recursive output of this functions on node.elements - - Returns a set of unique strings curated from recursing through the provided node - """ - collection = set() - - if instance_node.get_name() in subgraph_instances: - # pyre-ignore isinstance doesn't play well with imported dataclasses - collection.update(graph_fn(instance_node)) - # pyre-ignore - for node in instance_node.elements: - collection.update( - collect_recursively(find_instance_row(node), graph_fn, general_fn) - ) - else: - collection.update(general_fn(instance_node)) - - return collection - - # Populate Input/Output of subgraphs - for sub_graph in subgraph_instances.values(): - descendents = collect_recursively( - sub_graph, (lambda node: node.elements), (lambda node: []) - ) - descendent_inputs = collect_recursively( - sub_graph, - (lambda node: node.get_input_nodes()), - (lambda node: node.get_input_nodes()), - ) - descendent_outputs = collect_recursively( - sub_graph, - (lambda node: node.get_output_nodes()), - (lambda node: node.get_output_nodes()), - ) - - true_inputs = descendent_inputs - descendents - true_outputs = descendent_outputs - descendents - - sub_graph.input_nodes = list(true_inputs) - sub_graph.output_nodes = list(true_outputs) - - # Generate Summary Table - operator_summary = { - **_gen_op_summaries(list(operator_instances.values()), aggregated_op_table), - **{**_gen_module_summaries(list(subgraph_instances.values()))}, - } - - # Replace Operators with their parent groups - collapsed_operators = { - **{ - op_str: op - for op_str, op in operator_instances.items() - if op.parent_graph is None - }, - **{ - op_str: op - for op_str, op in subgraph_instances.items() - if op.parent_graph is None - }, - } - - # Print out tables - _print_all( - input_instances, - collapsed_operators, - operator_summary, - output_instances, - verbose, - ) - - # Start interactive debugging mode - enter_interactive_debugging( - input_instances, - operator_instances, - constant_instances, - operator_summary, - output_instances, - subgraph_instances, - verbose, - ) diff --git a/sdk/etdb/row_schema.py b/sdk/etdb/row_schema.py deleted file mode 100644 index 716c7de50e0..00000000000 --- a/sdk/etdb/row_schema.py +++ /dev/null @@ -1,363 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -from executorch.sdk.edir.base_schema import OperatorGraph, OperatorNode, ValueNode -from executorch.sdk.edir.et_schema import RESERVED_METADATA_ARG - - -@dataclass -class AbstractInstanceRow: - @staticmethod - def get_schema_header(verbose: bool = False) -> List[str]: - pass - - def to_row_format(self, verbose=False) -> Tuple[Any, ...]: - pass - - -@dataclass -class AbstractNodeInstanceRow(AbstractInstanceRow): - def get_name(self) -> str: - pass - - def get_input_nodes(self) -> List[str]: - pass - - def get_output_nodes(self) -> List[str]: - pass - - -@dataclass -class GraphInstanceRow(AbstractNodeInstanceRow): - name: str - elements: List[str] = field(default_factory=list) - metadata: Optional[Dict[str, Any]] = None - parent_graph: Optional[str] = None - - # Element counts - graph_count: int = 0 - operator_count: int = 0 - constant_count: int = 0 - - # Name identified rows - input_nodes: List[str] = field(default_factory=list) - output_nodes: List[str] = field(default_factory=list) - - @staticmethod - def gen_from_operator_node( - op_graph: OperatorGraph, parent_graph: Optional[str] = None - ) -> "GraphInstanceRow": - graph_count = 0 - operator_count = 0 - constant_count = 0 - elements = [] - for e in op_graph.elements: - if isinstance(e, OperatorGraph): - elements.append(e.graph_name) - graph_count += 1 - elif isinstance(e, OperatorNode): - elements.append(e.name) - operator_count += 1 - elif isinstance(e, ValueNode): - elements.append(e.name) - constant_count += 1 - - return GraphInstanceRow( - op_graph.graph_name, - elements, - op_graph.metadata, - parent_graph, - graph_count, - operator_count, - constant_count, - ) - - @staticmethod - def get_schema_header(verbose: bool = False, count_format=False) -> List[str]: - # (!!) Format to coincide with OpInstanceRow format - if not count_format: - return OpInstanceRow.get_schema_header(verbose) - - if verbose: - return [ - "Parent Graph", - "Name", - "Element Count", - "Graph Count", - "Operator Count", - "Constant Count", - ] - return ["Name", "Graph Count", "Operator Count"] - - def get_module_type(self) -> str: - return "[Sub Module] " + ( - self.metadata.get("module_type", "") if self.metadata is not None else "" - ) - - def to_row_format(self, verbose=False, count_format=False) -> Tuple[Any, ...]: - if count_format: - element_count = len(self.elements) if self.elements else 0 - if verbose: - return ( - self.parent_graph, - self.name, - element_count, - self.graph_count, - self.operator_count, - self.constant_count, - ) - return (self.name, self.graph_count, self.operator_count) - - # (!!) Format to coincide with OpInstanceRow format - module_type = self.get_module_type() - input_node_str = "\n".join(self.input_nodes) - output_node_str = "\n".join(self.output_nodes) - - if verbose: - return ( - self.parent_graph, - module_type, - self.name, - len(self.input_nodes), - len(self.output_nodes), - "", - input_node_str, - output_node_str, - ) - return (module_type, self.name, input_node_str, output_node_str) - - def get_name(self) -> str: - return self.name - - def get_input_nodes(self) -> List[str]: - return self.input_nodes - - def get_output_nodes(self) -> List[str]: - return self.output_nodes - - -@dataclass -class ValueInstanceRow(AbstractNodeInstanceRow): - dtype: str - name: str - val: Any - metadata: Optional[Dict[str, Any]] = None - parent_graph: Optional[str] = None - - # Name identified rows - input_nodes: List[str] = field(default_factory=list) - output_nodes: List[str] = field(default_factory=list) - - # Note: This does not populate output nodes - @staticmethod - def gen_from_operator_node( - value_node: ValueNode, parent_graph: Optional[str] = None - ) -> "ValueInstanceRow": - input_nodes = [e.name for e in value_node.inputs] if value_node.inputs else [] - return ValueInstanceRow( - value_node.dtype, - value_node.name, - value_node.val, - value_node.metadata, - parent_graph, - input_nodes, - [], - ) - - @staticmethod - def get_schema_header(verbose: bool = False) -> List[str]: - if verbose: - return [ - "Parent Graph", - "Dtype", - "Name", - "Value (Shape if Tensor)", - "Input Count", - "Output Count", - "Input Nodes", - "Output Nodes", - ] - return [ - "Dtype", - "Name", - "Value (Shape if Tensor)", - "Input Nodes", - "Output Nodes", - ] - - def to_row_format(self, verbose=False) -> Tuple[Any, ...]: - row = (self.dtype, self.name, self.val) - if verbose: - row = (self.parent_graph,) + row - row += (len(self.input_nodes), len(self.output_nodes)) - - return row + ("\n".join(self.input_nodes), "\n".join(self.output_nodes)) - - def get_name(self) -> str: - return self.name - - def get_input_nodes(self) -> List[str]: - return self.input_nodes - - def get_output_nodes(self) -> List[str]: - return self.output_nodes - - -@dataclass -class OpInstanceRow(AbstractNodeInstanceRow): - op: str - name: str - output_shapes: Optional[List[List[int]]] = None - metadata: Optional[Dict[str, Any]] = None - parent_graph: Optional[str] = None - - # Name identified rows - input_nodes: List[str] = field(default_factory=list) - output_nodes: List[str] = field(default_factory=list) - - @staticmethod - def gen_from_operator_node( - operator_node: OperatorNode, parent_graph: Optional[str] = None - ) -> "OpInstanceRow": - op = operator_node.op if operator_node.op is not None else "Unknown" - input_nodes = ( - [e.name for e in operator_node.inputs] if operator_node.inputs else [] - ) - return OpInstanceRow( - op, - operator_node.name, - operator_node.output_shapes, - operator_node.metadata, - parent_graph, - input_nodes, - ) - - @staticmethod - def get_schema_header(verbose: bool = False) -> List[str]: - if verbose: - return [ - "Parent Group", - "Op (Module if annotated)", - "Name", - "Input Count", - "Output Count", - "Output Shapes", - "Input Nodes", - "Output Nodes", - "Coldstart (ms)", - "Mean (ms)", - "Min (ms)", - "P10 (ms)", - "P90 (ms)", - "Max (ms)", - ] - return [ - "Op (Module if annotated)", - "Name", - "Input Nodes", - "Output Nodes", - ] - - def to_row_format(self, verbose=False) -> Tuple[Any, ...]: - input_node_str = "\n".join(self.input_nodes) - output_node_str = "\n".join(self.output_nodes) - - output_shape_str = ( - "\n".join([str(shape) for shape in self.output_shapes]) - if self.output_shapes is not None - else "" - ) - - if verbose: - row = ( - self.parent_graph, - self.op, - self.name, - len(self.input_nodes), - len(self.output_nodes), - output_shape_str, - input_node_str, - output_node_str, - ) - - # Note: These fields can be opaquely extracted from a keyed metadta field - metadata = self.metadata - if ( - metadata is not None - and RESERVED_METADATA_ARG.METRICS_KEYWORD.value in metadata - ): - metrics = metadata[RESERVED_METADATA_ARG.METRICS_KEYWORD.value] - coldstart_ms = metrics.get( - RESERVED_METADATA_ARG.PROFILE_SUMMARY_COLDSTART.value, None - ) - mean_ms = metrics.get( - RESERVED_METADATA_ARG.PROFILE_SUMMARY_AVERAGE.value, None - ) - min_ms = metrics.get( - RESERVED_METADATA_ARG.PROFILE_SUMMARY_MIN.value, None - ) - p10_ms = metrics.get( - RESERVED_METADATA_ARG.PROFILE_SUMMARY_P10.value, None - ) - p90_ms = metrics.get( - RESERVED_METADATA_ARG.PROFILE_SUMMARY_P90.value, None - ) - max_ms = metrics.get( - RESERVED_METADATA_ARG.PROFILE_SUMMARY_MAX.value, None - ) - row += (coldstart_ms, mean_ms, min_ms, p10_ms, p90_ms, max_ms) - - return row - - return (self.op, self.name, input_node_str, output_node_str) - - def get_name(self) -> str: - return self.name - - def get_input_nodes(self) -> List[str]: - return self.input_nodes - - def get_output_nodes(self) -> List[str]: - return self.output_nodes - - -@dataclass -class OpSummaryRow(AbstractInstanceRow): - op: str - elements: List[AbstractNodeInstanceRow] - - # Summary Stats - mean_ms: Optional[float] = None - min_ms: Optional[float] = None - p10_ms: Optional[float] = None - p90_ms: Optional[float] = None - max_ms: Optional[float] = None - - @staticmethod - def get_schema_header(verbose: bool = False) -> List[str]: - return [ - "Op (Module if annotated)", - "Instance Count", - "Mean (ms)", - "Min (ms)", - "P10 (ms)", - "P90 (ms)", - "Max (ms)", - ] - - def to_row_format(self, verbose=False) -> Tuple[Any, ...]: - return ( - self.op, - len(self.elements), - self.mean_ms, - self.min_ms, - self.p10_ms, - self.p90_ms, - self.max_ms, - ) diff --git a/sdk/lib.py b/sdk/lib.py deleted file mode 100644 index 6a1656c545c..00000000000 --- a/sdk/lib.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import asyncio -import os -from typing import Mapping, Optional - -from executorch.sdk.edir.et_schema import ( - FXOperatorGraph, - InferenceRun, - OperatorGraphWithStats, -) -from executorch.sdk.etdb.etdb import debug_graphs -from executorch.sdk.etrecord import ETRecord, parse_etrecord - -""" -Private Lib Helpers -""" - - -def _gen_graphs_from_etrecord( - etrecord: ETRecord, -) -> Mapping[str, OperatorGraphWithStats]: - if etrecord.graph_map is None: - return {} - return { - name: FXOperatorGraph.gen_operator_graph(exported_program.graph_module) - for name, exported_program in etrecord.graph_map.items() - } - - -def _gen_and_attach_metadata( - op_graph_dict: Mapping[str, OperatorGraphWithStats], et_dump_path: str -) -> None: - """ - (!!) Note: Currently we only support attaching etdump data to the - et_dialect_graph_module. - - Attach metadata in ETDump under path et_dump_path to the given op_graph. - To visualize op_graph without ETDump metadata, this function can be skipped. - - Args: - op_graph (ExportedETOperatorGraph): operator graph to visualize - et_dump_path (str): local or Manifold path to the ETDump - """ - - op_graph = op_graph_dict["et_dialect_graph_module/forward"] - - if os.path.exists(et_dump_path): - op_graph.attach_metadata( - inference_run=InferenceRun.extract_runs_from_path(file_path=et_dump_path)[0] - ) - else: - raise Exception("Invalid ET Dump path") - - -""" -SDK Entry Points -""" - - -def debug_etrecord( - etrecord: ETRecord, et_dump_path: Optional[str] = None, verbose: bool = False -): - """ - Given an ETRecord, kick off ETDB - """ - op_graph_dict: Mapping[str, OperatorGraphWithStats] = _gen_graphs_from_etrecord( - etrecord - ) - if et_dump_path is not None: - _gen_and_attach_metadata(op_graph_dict, et_dump_path) - debug_graphs(op_graph_dict, verbose) - - -def debug_etrecord_path( - etrecord_path: str, et_dump_path: Optional[str] = None, verbose: bool = False -): - """ - Given a path to an ETRecord, kick off ETDB - """ - debug_etrecord(parse_etrecord(etrecord_path), et_dump_path, verbose) - - -""" -SDK Binary -""" - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("et_record", help="Path to ETRecord") - parser.add_argument("--et_dump", help="Path to ET Dump") - parser.add_argument( - "--verbose", - action="store_true", - help="Whether the terminal should display in verbose mode", - ) - return parser.parse_args() - - -async def main() -> int: - """ - Simple CLI wrapper for triggering ETDB - - Only required argument is an et_record path - """ - args = parse_args() - debug_etrecord_path(args.et_record, args.et_dump, args.verbose) - return 0 - - -if __name__ == "__main__": - asyncio.run(main())