diff --git a/sdk/inspector/TARGETS b/sdk/inspector/TARGETS index 965e53ff77d..e87e9646877 100644 --- a/sdk/inspector/TARGETS +++ b/sdk/inspector/TARGETS @@ -16,8 +16,9 @@ python_library( ":inspector_utils", "//caffe2:torch", "//executorch/exir:lib", - "//executorch/sdk/edir:base_schema", + "//executorch/sdk/edir:et_schema", "//executorch/sdk/etdump:schema_flatcc", + "//executorch/sdk/etrecord:etrecord", ], ) diff --git a/sdk/inspector/_inspector_utils.py b/sdk/inspector/_inspector_utils.py index e1054f89876..cfa38b9da13 100644 --- a/sdk/inspector/_inspector_utils.py +++ b/sdk/inspector/_inspector_utils.py @@ -12,7 +12,7 @@ from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc -from executorch.sdk.etrecord import ETRecord, parse_etrecord +from executorch.sdk.etrecord import ETRecord EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module" @@ -61,13 +61,6 @@ def create_debug_handle_to_op_node_mapping( debug_handle_to_op_node_map[debug_handle] = element -def gen_etrecord_object(etrecord_path: Optional[str] = None) -> ETRecord: - # Gen op graphs from etrecord - if etrecord_path is None: - raise ValueError("Etrecord_path must be specified.") - return parse_etrecord(etrecord_path=etrecord_path) - - def gen_etdump_object(etdump_path: Optional[str] = None) -> ETDumpFlatCC: # Gen event blocks from etdump if etdump_path is None: diff --git a/sdk/inspector/inspector.py b/sdk/inspector/inspector.py index 18c524fbcef..a48ff1f2d7a 100644 --- a/sdk/inspector/inspector.py +++ b/sdk/inspector/inspector.py @@ -25,13 +25,13 @@ import torch from executorch.exir import ExportedProgram -from executorch.sdk.edir.base_schema import OperatorGraph, OperatorNode +from executorch.sdk.edir.et_schema import OperatorNode from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent +from executorch.sdk.etrecord import parse_etrecord from executorch.sdk.inspector._inspector_utils import ( create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, gen_etdump_object, - gen_etrecord_object, gen_graphs_from_etrecord, ) @@ -368,7 +368,6 @@ class Inspector: Private Attributes: _etrecord: Optional[ETRecord]. File under etrecord_path deserialized into an object. - _op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names. """ def __init__( @@ -387,14 +386,18 @@ def __init__( defaults to milli (1000ms = 1s). """ - # TODO: etrecord_path can be optional, so need to support the case when it is not present - self._etrecord = gen_etrecord_object(etrecord_path=etrecord_path) + self._etrecord = ( + parse_etrecord(etrecord_path=etrecord_path) + if etrecord_path is not None + else None + ) + etdump = gen_etdump_object(etdump_path=etdump_path) self.event_blocks = EventBlock._gen_from_etdump(etdump, etdump_scale) - self._op_graph_dict: Mapping[str, OperatorGraph] = gen_graphs_from_etrecord( - etrecord=self._etrecord - ) + # No additional data association can be done without ETRecord, so return early + if self._etrecord is None: + return # Use the delegate map from etrecord, associate debug handles with each event for event_block in self.event_blocks: @@ -406,9 +409,10 @@ def __init__( ) # Traverse the edge dialect op graph to create mapping from debug_handle to op node + op_graph_dict = gen_graphs_from_etrecord(etrecord=self._etrecord) debug_handle_to_op_node_map = {} create_debug_handle_to_op_node_mapping( - self._op_graph_dict[EDGE_DIALECT_GRAPH_KEY], + op_graph_dict[EDGE_DIALECT_GRAPH_KEY], debug_handle_to_op_node_map, ) @@ -479,13 +483,22 @@ def write_tensorboard_artifact(self, path: str) -> None: # TODO: implement pass - def get_exported_program(self, graph: Optional[str] = None) -> ExportedProgram: + def get_exported_program( + self, graph: Optional[str] = None + ) -> Optional[ExportedProgram]: """ Access helper for ETRecord, defaults to returning Edge Dialect Program Args: graph: Name of the graph to access. If None, returns the Edge Dialect Program. """ - if graph is None: - return self._etrecord.edge_dialect_program - return self._etrecord.graph_map.get(graph) + if self._etrecord is None: + log.warning( + "Exported program is only available when a valid etrecord_path was provided at the time of Inspector construction" + ) + return None + return ( + self._etrecord.edge_dialect_program + if graph is None + else self._etrecord.graph_map.get(graph) + ) diff --git a/sdk/inspector/tests/inspector_test.py b/sdk/inspector/tests/inspector_test.py index 494489a3b36..e55984fe952 100644 --- a/sdk/inspector/tests/inspector_test.py +++ b/sdk/inspector/tests/inspector_test.py @@ -54,8 +54,8 @@ def test_event_block_to_dataframe(self) -> None: def test_inspector_constructor(self): # Create a context manager to patch functions called by Inspector.__init__ with patch.object( - inspector, "gen_etrecord_object", return_value=None - ) as mock_gen_etrecord, patch.object( + inspector, "parse_etrecord", return_value=None + ) as mock_parse_etrecord, patch.object( inspector, "gen_etdump_object", return_value=None ) as mock_gen_etdump, patch.object( EventBlock, "_gen_from_etdump" @@ -69,20 +69,17 @@ def test_inspector_constructor(self): ) # Assert that expected functions are called - mock_gen_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH) + mock_parse_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH) mock_gen_etdump.assert_called_once_with(etdump_path=ETDUMP_PATH) mock_gen_from_etdump.assert_called_once() - mock_gen_graphs_from_etrecord.assert_called_once() + # Because we mocked parse_etrecord() to return None, this method shouldn't be called + mock_gen_graphs_from_etrecord.assert_not_called() def test_inspector_get_event_blocks_and_print_data_tabular(self): # Create a context manager to patch functions called by Inspector.__init__ - with patch.object( - inspector, "gen_etrecord_object", return_value=None - ), patch.object( + with patch.object(inspector, "parse_etrecord", return_value=None), patch.object( inspector, "gen_etdump_object", return_value=None - ), patch.object( - EventBlock, "_gen_from_etdump" - ), patch.object( + ), patch.object(EventBlock, "_gen_from_etdump"), patch.object( inspector, "gen_graphs_from_etrecord" ): # Call the constructor of Inspector @@ -189,13 +186,9 @@ def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self): def test_inspector_get_exported_program(self): # Create a context manager to patch functions called by Inspector.__init__ - with patch.object( - inspector, "gen_etrecord_object", return_value=None - ), patch.object( + with patch.object(inspector, "parse_etrecord", return_value=None), patch.object( inspector, "gen_etdump_object", return_value=None - ), patch.object( - EventBlock, "_gen_from_etdump" - ), patch.object( + ), patch.object(EventBlock, "_gen_from_etdump"), patch.object( inspector, "gen_graphs_from_etrecord" ), patch.object( inspector, "create_debug_handle_to_op_node_mapping"