Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import List, Optional
from typing import Dict, List, Optional, Union

from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.pass_manager import PassType
Expand Down Expand Up @@ -45,7 +45,12 @@ class EdgeCompileConfig:
@dataclass
class ExecutorchBackendConfig:
passes: List[PassType] = field(default_factory=list)
memory_planning_pass: PassType = MemoryPlanningPass("greedy")

# A single memory planning pass can be defined for all the programs in the
# EdgeProgramManager or can be defined per program.
memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass(
"greedy"
)
to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False)
dynamic_memory_planning_mode: DynamicMemoryPlanningMode = (
DynamicMemoryPlanningMode.UPPER_BOUND
Expand Down
17 changes: 12 additions & 5 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def to_executorch(
# Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet.
# After exir.capture is gone I will clean up the memory planning infra to be consistent.
# Frankly all of exir has big code quality issues because of the migrations that need to be addressed.
new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[19]
new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[29]
assert new_gm_res is not None
new_gm = new_gm_res.graph_module
new_prog = ExirExportedProgram(
Expand Down Expand Up @@ -889,7 +889,8 @@ def to_backend(
)

def to_executorch(
self, config: Optional[ExecutorchBackendConfig] = None
self,
config: Optional[ExecutorchBackendConfig] = None,
) -> "ExecutorchProgramManager":
"""
Transforms the program to the ExecuTorch backend.
Expand Down Expand Up @@ -926,13 +927,19 @@ def to_executorch(
# TODO(who?)
p.update_placeholder_tensor_specs(program, new_gm)

if isinstance(config.memory_planning_pass, dict):
memory_planning_pass = config.memory_planning_pass.get(
name, ExecutorchBackendConfig().memory_planning_pass
)
else:
memory_planning_pass = config.memory_planning_pass
# TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
if hasattr(config.memory_planning_pass, "run"):
new_gm_res = config.memory_planning_pass.run( # pyre-ignore[16]
if hasattr(memory_planning_pass, "run"):
new_gm_res = memory_planning_pass.run( # pyre-ignore[16]
new_gm, new_signature
)
else:
new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[19]
new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29]
assert new_gm_res is not None
new_gm = new_gm_res.graph_module

Expand Down
40 changes: 40 additions & 0 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from executorch.exir.error import ExportError
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.program._program import (
EdgeProgramManager,
ExecutorchProgramManager,
Expand Down Expand Up @@ -160,6 +161,45 @@ def test_executorch_manager_basic_api(self):
3,
)

def test_executorch_manager_multi_config(self):
def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
return {
"forward": MemoryPlanningPass(
memory_planning_algo="greedy",
alloc_graph_input=True,
alloc_graph_output=False,
),
"foo": MemoryPlanningPass(
memory_planning_algo="greedy",
alloc_graph_input=False,
alloc_graph_output=True,
),
}

executorch_manager: ExecutorchProgramManager = to_edge(
get_exported_programs(), get_config_methods()
).to_executorch(
ExecutorchBackendConfig(
memory_planning_pass=get_executorch_memory_planning_passes()
)
)

method = executorch_manager._emitter_output.program.execution_plan[0]
if method.name == "forward":
for input_val in method.inputs:
evalue = method.values[input_val]
self.assertEqual(evalue.val.allocation_info, None)
for output_val in method.outputs:
evalue = method.values[output_val]
self.assertNotEqual(evalue.val.allocation_info, None)
else:
for input_val in method.inputs:
evalue = method.values[input_val]
self.assertEqual(evalue.val.allocation_info, None)
for output_val in method.outputs:
evalue = method.values[output_val]
self.assertNotEqual(evalue.val.allocation_info, None)

def test_no_getattr(self):
class Mul(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down