From 407ccb62a86473655015102220a45dfb698e8fb0 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 27 Jan 2023 10:14:24 +0900
Subject: [PATCH] [MetaSchedule] Fix for RewriteLayout + AllocateConst when the rank of the rewritten weight doesn't change

---
 src/relay/backend/te_compiler_cache.cc             | 21 +++++-
 .../test_meta_schedule_relay_integration.py        | 74 +++++++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 154101fc94fe..c680c5a77e04 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -576,7 +576,26 @@ class ScheduleBuilder : public ExprVisitor {
           << "Only one layout-free constant is supported by RewriteLayout for now";
       auto constant = const_collector.constants[0];
 
-      if (constant.Shape().size() == index_map->initial_indices.size()) {
+      auto is_constant_transformed = [index_map](runtime::NDArray c) {
+        if (c.Shape().size() != index_map->initial_indices.size()) {
+          return true;
+        }
+        size_t src_size_1d = 1;
+        Array<PrimExpr> orig_shape;
+        for (size_t i = 0; i < c.Shape().size(); ++i) {
+          src_size_1d *= c->shape[i];
+          orig_shape.push_back(PrimExpr(static_cast<int>((c->shape[i]))));
+        }
+        auto dst_shape = index_map->MapShape(orig_shape);
+        std::vector<int> dst_shape_int;
+        size_t dst_size_1d = 1;
+        for (size_t i = 0; i < dst_shape.size(); ++i) {
+          dst_size_1d *= dst_shape[i].as<IntImmNode>()->value;
+        }
+        return src_size_1d != dst_size_1d;
+      };
+
+      if (!is_constant_transformed(constant)) {
         // This is the first case, reached during the MetaScheduleLayoutRewrite pass.
         //
         // A layout-free constant having the same rank as an input to the index map
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 795890de083e..8cd58e5a6f36 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -880,5 +880,79 @@ def test_disabled_pass_param():
         pytest.fail("'disabled_pass' argument does not work")
 
 
+def test_rewrite_layout_link_params_1x1_conv2d():
+    I, O, H, W = 32, 16, 256, 256
+    kH = kW = 1
+
+    strides = (1, 1)
+    padding = (0, 0)
+
+    data_shape = (1, H, W, I)
+    w_shape = (kH, kW, I, O)
+
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=w_shape, dtype="float32")
+
+    conv = relay.nn.conv2d(
+        data=data,
+        weight=weight,
+        kernel_size=(kH, kW),
+        channels=O,
+        padding=padding,
+        strides=strides,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="float32",
+    )
+
+    mod = tvm.IRModule.from_expr(conv)
+
+    weight_np = np.random.randn(*w_shape).astype("float32")
+
+    params = {"weight": weight_np}
+
+    data_np = np.random.randn(*data_shape).astype("float32")
+
+    ref = (
+        relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
+        .evaluate()(*[data_np, weight_np])
+        .numpy()
+    )
+
+    link_params = True
+
+    target = "llvm --num-cores=4"
+
+    executor = relay.backend.Executor("graph", {"link-params": link_params})
+    mod = mod.with_attr("executor", executor)
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = ms.relay_integration.tune_relay(
+            mod=mod,
+            target=target,
+            params=params,
+            work_dir=work_dir,
+            max_trials_global=8,
+            strategy="replay-trace",
+        )
+
+        lib = ms.relay_integration.compile_relay(
+            database=database,
+            mod=mod,
+            target=target,
+            params=params,
+        )
+
+    dev = tvm.device(target, 0)
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data_np)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+
+    np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
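
A minimal Python sketch (not part of the patch) of the shape mapping that the new is_constant_transformed lambda relies on. It assumes the Python-side tvm.tir.IndexMap.from_func and IndexMap.map_shape, the counterparts of the C++ IndexMap::MapShape used above; both index maps below are hypothetical stand-ins for whatever RewriteLayout actually produces, chosen only to show that a layout rewrite can change a weight's shape without changing its rank, which is why the old rank-only comparison is not enough for the 1x1 conv2d case exercised by the test.

    import tvm
    from tvm.tir import IndexMap

    # 1x1 conv2d HWIO weight, matching the shapes in the test above.
    kH, kW, I, O = 1, 1, 32, 16

    # Rank-changing rewrite: the old rank comparison already detects this case.
    pack_o = IndexMap.from_func(lambda kh, kw, i, o: [kh, kw, i, o // 4, o % 4])
    print(pack_o.map_shape([kH, kW, I, O]))  # rank 4 -> 5, e.g. [1, 1, 32, 4, 4]

    # Rank-preserving rewrite: the rank check alone is ambiguous here, so the
    # patch additionally compares element counts before and after applying the
    # map to the constant's shape (src_size_1d vs. dst_size_1d in the C++ lambda).
    permute = IndexMap.from_func(lambda kh, kw, i, o: [i // 4, o, kh * kw, i % 4])
    print(permute.map_shape([kH, kW, I, O]))  # still rank 4, e.g. [8, 16, 1, 4]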