Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,8 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
sram = extract_memory_info(workspace_memory_pools.pools[0], memory_pressure)
tir_mod = LowerToTIR(_ethos_u55_cascader(sram, util.is_striping_enabled()))(mod)
else:
tir_mod = LowerToTIR(copy_constants())(mod)
scheduler = None if util.is_copying_constants_disabled() else copy_constants()
tir_mod = LowerToTIR(scheduler)(mod)

return tir_mod

Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def is_cascader_enabled():
return compiler_attrs.enable_cascader


def is_copying_constants_disabled() -> bool:
"""Determine whether copying constants is disabled for case without cascader"""
compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
return bool(compiler_attrs.disable_copying_constants)


def is_striping_enabled():
"""Determine whether the cascader is enabled"""
compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
Expand Down
9 changes: 9 additions & 0 deletions src/relay/backend/contrib/ethosu/compiler_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
String accelerator_config;
bool enable_cascader;
bool enable_striping;
Bool disable_copying_constants = Bool(false);
String dev_force_block_config;
String dev_max_open_plans;
String dev_max_closed_plans;
Expand All @@ -59,6 +60,14 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
TVM_ATTR_FIELD(enable_cascader)
.describe("Whether the cascader should be enabled")
.set_default(false);
TVM_ATTR_FIELD(disable_copying_constants)
.describe(
"Whether copying constants is disabled for case without the cascader. When this option "
"is "
"enabled, it is assumed that the constants should be located in SRAM (user determines "
"in "
"the linker script for section \".rodata.tvm\" that the constants are located in SRAM)")
.set_default(Bool(false));
TVM_ATTR_FIELD(enable_striping)
.describe("Whether the cascader should be striping")
.set_default(false);
Expand Down
16 changes: 16 additions & 0 deletions tests/python/contrib/test_ethosu/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ def test_copy_constants():
assert ".global" in sch.stages[19].op.name


def test_no_copy_constants():
ifm_a = relay.var("IFM_A", shape=(1, 26, 26, 32), dtype="int8")
conv_a = make_ethosu_conv2d(ifm_a, 32, 8, (3, 3), (0, 0), (1, 1), (1, 1))
conv_b = make_ethosu_conv2d(conv_a, 8, 4, (1, 1), (0, 0), (1, 1), (1, 1))
func = relay.Function(relay.analysis.free_vars(conv_b), conv_b)
func = run_opt_pass(func, relay.transform.InferType())

func, _ = extract_constants(func)
cached_func = lower_to_te(func)

sch = te.create_schedule([cached_func.outputs[0].op])
assert len(sch.stages) == 19
ops_names = [x.op.name for x in sch.stages]
assert all(".global" not in x for x in ops_names)


# This test makes sure that constants and LUTs have a correct storage scope
def test_copy_luts():
ifm_shape = (1, 33, 33, 11)
Expand Down