From d6c1a3d5e65fb8559f44169089eae68580e30c2a Mon Sep 17 00:00:00 2001 From: Tarun Karuturi Date: Thu, 9 May 2024 13:45:59 -0700 Subject: [PATCH] Add support for method level executorch backend config (#3266) Summary: There are use cases where we might like to supply a separate ExecutorchBackendConfig for each method in the model. An example use case is where we might want to alloc inputs for one method and not alloc them for another. In order to support this, in this diff we add support for passing in a dictionary of configs to `to_executorch`. Reviewed By: JacobSzwejbka, cccclai Differential Revision: D56499598 --- exir/capture/_config.py | 9 +++++-- exir/program/_program.py | 17 +++++++++---- exir/program/test/test_program.py | 40 +++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/exir/capture/_config.py b/exir/capture/_config.py index c03be0e24f3..dd0ed94094f 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field -from typing import List, Optional +from typing import Dict, List, Optional, Union from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.pass_manager import PassType @@ -45,7 +45,12 @@ class EdgeCompileConfig: @dataclass class ExecutorchBackendConfig: passes: List[PassType] = field(default_factory=list) - memory_planning_pass: PassType = MemoryPlanningPass("greedy") + + # A single memory planning pass can be defined for all the programs in the + # EdgeProgramManager or can be defined per program. + memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass( + "greedy" + ) to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False) dynamic_memory_planning_mode: DynamicMemoryPlanningMode = ( DynamicMemoryPlanningMode.UPPER_BOUND diff --git a/exir/program/_program.py b/exir/program/_program.py index f2c2a5438fd..c5afe011691 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -412,7 +412,7 @@ def to_executorch( # Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet. # After exir.capture is gone I will clean up the memory planning infra to be consistent. # Frankly all of exir has big code quality issues because of the migrations that need to be addressed. - new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[19] + new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[29] assert new_gm_res is not None new_gm = new_gm_res.graph_module new_prog = ExirExportedProgram( @@ -889,7 +889,8 @@ def to_backend( ) def to_executorch( - self, config: Optional[ExecutorchBackendConfig] = None + self, + config: Optional[ExecutorchBackendConfig] = None, ) -> "ExecutorchProgramManager": """ Transforms the program to the ExecuTorch backend. @@ -926,13 +927,19 @@ def to_executorch( # TODO(who?) p.update_placeholder_tensor_specs(program, new_gm) + if isinstance(config.memory_planning_pass, dict): + memory_planning_pass = config.memory_planning_pass.get( + name, ExecutorchBackendConfig().memory_planning_pass + ) + else: + memory_planning_pass = config.memory_planning_pass # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work - if hasattr(config.memory_planning_pass, "run"): - new_gm_res = config.memory_planning_pass.run( # pyre-ignore[16] + if hasattr(memory_planning_pass, "run"): + new_gm_res = memory_planning_pass.run( # pyre-ignore[16] new_gm, new_signature ) else: - new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[19] + new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29] assert new_gm_res is not None new_gm = new_gm_res.graph_module diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 51f0fcf0788..f84f0c1cd02 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -16,6 +16,7 @@ from executorch.exir.error import ExportError from executorch.exir.lowered_backend_module import get_lowered_submodules from executorch.exir.pass_base import ExportPass +from executorch.exir.passes import MemoryPlanningPass from executorch.exir.program._program import ( EdgeProgramManager, ExecutorchProgramManager, @@ -160,6 +161,45 @@ def test_executorch_manager_basic_api(self): 3, ) + def test_executorch_manager_multi_config(self): + def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]: + return { + "forward": MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=True, + alloc_graph_output=False, + ), + "foo": MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=False, + alloc_graph_output=True, + ), + } + + executorch_manager: ExecutorchProgramManager = to_edge( + get_exported_programs(), get_config_methods() + ).to_executorch( + ExecutorchBackendConfig( + memory_planning_pass=get_executorch_memory_planning_passes() + ) + ) + + method = executorch_manager._emitter_output.program.execution_plan[0] + if method.name == "forward": + for input_val in method.inputs: + evalue = method.values[input_val] + self.assertEqual(evalue.val.allocation_info, None) + for output_val in method.outputs: + evalue = method.values[output_val] + self.assertNotEqual(evalue.val.allocation_info, None) + else: + for input_val in method.inputs: + evalue = method.values[input_val] + self.assertEqual(evalue.val.allocation_info, None) + for output_val in method.outputs: + evalue = method.values[output_val] + self.assertNotEqual(evalue.val.allocation_info, None) + def test_no_getattr(self): class Mul(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: