diff --git a/python/tvm/driver/tvmc/workspace_pools.py b/python/tvm/driver/tvmc/workspace_pools.py index 2c91488fb48b..fe304f7fc0af 100644 --- a/python/tvm/driver/tvmc/workspace_pools.py +++ b/python/tvm/driver/tvmc/workspace_pools.py @@ -161,16 +161,6 @@ def workspace_pools_recombobulate(parsed, targets, extra_target): "workspace_pools_target_burst_bytes", ] - # Load extra targets from CLI - additional_targets = [] - - for t in extra_target: - additional_targets.append(Target(t["raw"], host=targets[0].host or targets[0])) - - target = targets + additional_targets - if targets[0].host: - target.append(targets[0].host) - workspace_pools = _split_pools_to_pool_names(parsed.workspace_pools) if not workspace_pools: return None @@ -186,6 +176,16 @@ def workspace_pools_recombobulate(parsed, targets, extra_target): for workspace_pool_param in WORKSPACE_POOL_TARGET_PARAMS } + # Load extra targets from CLI + additional_targets = [] + + for t in extra_target: + additional_targets.append(Target(t["raw"], host=targets[0].host or targets[0])) + + target = targets + additional_targets + if targets[0].host: + target.append(targets[0].host) + return WorkspaceMemoryPools( [ WorkspacePoolInfo( diff --git a/tests/python/driver/tvmc/test_workspace_pools.py b/tests/python/driver/tvmc/test_workspace_pools.py index 386181aaf20b..2e34c90252c3 100644 --- a/tests/python/driver/tvmc/test_workspace_pools.py +++ b/tests/python/driver/tvmc/test_workspace_pools.py @@ -18,6 +18,7 @@ import pytest import argparse +import tvm from tvm.driver.tvmc.workspace_pools import ( generate_workspace_pools_args, workspace_pools_recombobulate, @@ -402,3 +403,18 @@ def test_workspace_pools_recombobulate_single_pool_overrides(): assert len(memory_pools.pools[0].targets) == 2 assert len(memory_pools.pools[1].targets) == 1 + + +@tvm.testing.requires_ethosn +def test_workspace_pools_recombobulate_ext_codegen(): + """No error should occur when using an external code generator without an attached Target""" + + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args([]) + + targets = [Target("llvm")] + extra_targets = [{"raw": "ethos-n"}] + + memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets) + assert memory_pools is None