diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index c03be0e24f3..dd0ed94094f 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Dict, List, Optional, Union
 
 from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
 from executorch.exir.pass_manager import PassType
@@ -45,7 +45,12 @@ class EdgeCompileConfig:
 @dataclass
 class ExecutorchBackendConfig:
     passes: List[PassType] = field(default_factory=list)
-    memory_planning_pass: PassType = MemoryPlanningPass("greedy")
+
+    # A single memory planning pass can be defined for all the programs in the
+    # EdgeProgramManager or can be defined per program.
+    memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass(
+        "greedy"
+    )
     to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False)
     dynamic_memory_planning_mode: DynamicMemoryPlanningMode = (
         DynamicMemoryPlanningMode.UPPER_BOUND
diff --git a/exir/program/_program.py b/exir/program/_program.py
index f2c2a5438fd..c5afe011691 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -412,7 +412,7 @@ def to_executorch(
     # Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet.
     # After exir.capture is gone I will clean up the memory planning infra to be consistent.
     # Frankly all of exir has big code quality issues because of the migrations that need to be addressed.
-    new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[19]
+    new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[29]
     assert new_gm_res is not None
     new_gm = new_gm_res.graph_module
     new_prog = ExirExportedProgram(
@@ -889,7 +889,8 @@ def to_backend(
         )
 
     def to_executorch(
-        self, config: Optional[ExecutorchBackendConfig] = None
+        self,
+        config: Optional[ExecutorchBackendConfig] = None,
     ) -> "ExecutorchProgramManager":
         """
         Transforms the program to the ExecuTorch backend.
@@ -926,13 +927,19 @@ def to_executorch(
             # TODO(who?)
             p.update_placeholder_tensor_specs(program, new_gm)
+            if isinstance(config.memory_planning_pass, dict):
+                memory_planning_pass = config.memory_planning_pass.get(
+                    name, ExecutorchBackendConfig().memory_planning_pass
+                )
+            else:
+                memory_planning_pass = config.memory_planning_pass
+
             # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
-            if hasattr(config.memory_planning_pass, "run"):
-                new_gm_res = config.memory_planning_pass.run(  # pyre-ignore[16]
+            if hasattr(memory_planning_pass, "run"):
+                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
                     new_gm, new_signature
                 )
             else:
-                new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[19]
+                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
 
diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py
index 51f0fcf0788..f84f0c1cd02 100644
--- a/exir/program/test/test_program.py
+++ b/exir/program/test/test_program.py
@@ -16,6 +16,7 @@
 from executorch.exir.error import ExportError
 from executorch.exir.lowered_backend_module import get_lowered_submodules
 from executorch.exir.pass_base import ExportPass
+from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.program._program import (
     EdgeProgramManager,
     ExecutorchProgramManager,
@@ -160,6 +161,45 @@ def test_executorch_manager_basic_api(self):
             3,
         )
 
+    def test_executorch_manager_multi_config(self):
+        def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
+            return {
+                "forward": MemoryPlanningPass(
+                    memory_planning_algo="greedy",
+                    alloc_graph_input=True,
+                    alloc_graph_output=False,
+                ),
+                "foo": MemoryPlanningPass(
+                    memory_planning_algo="greedy",
+                    alloc_graph_input=False,
+                    alloc_graph_output=True,
+                ),
+            }
+
+        executorch_manager: ExecutorchProgramManager = to_edge(
+            get_exported_programs(), get_config_methods()
+        ).to_executorch(
+            ExecutorchBackendConfig(
+                memory_planning_pass=get_executorch_memory_planning_passes()
+            )
+        )
+
+        method = executorch_manager._emitter_output.program.execution_plan[0]
+        if method.name == "forward":
+            for input_val in method.inputs:
+                evalue = method.values[input_val]
+                self.assertNotEqual(evalue.val.allocation_info, None)
+            for output_val in method.outputs:
+                evalue = method.values[output_val]
+                self.assertEqual(evalue.val.allocation_info, None)
+        else:
+            for input_val in method.inputs:
+                evalue = method.values[input_val]
+                self.assertEqual(evalue.val.allocation_info, None)
+            for output_val in method.outputs:
+                evalue = method.values[output_val]
+                self.assertNotEqual(evalue.val.allocation_info, None)
+
     def test_no_getattr(self):
         class Mul(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor: