From 66d59801a72b218ad415b77fed9c5c066218c1eb Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Fri, 22 Oct 2021 02:07:23 -0400 Subject: [PATCH 1/3] [Fixbug] Report duplicated param names of relay function when bind params --- src/relay/backend/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 6d59b858927c..3f72ad5d0475 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -327,7 +327,7 @@ inline relay::Function BindParamsByName( for (auto arg : func->params) { const auto& name = arg->name_hint(); if (name_dict.count(name)) { - repeat_var.insert(arg); + repeat_var.insert(name_dict[name]); } else { name_dict[name] = arg; } From b41957c38bb11f6167fb35091b869483c0b76ad1 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 23 Oct 2021 01:17:13 -0400 Subject: [PATCH 2/3] add test --- tests/python/relay/test_ir_bind.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index b179096a0528..a48bc856d6df 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. """ test bind function.""" +import pytest import tvm from tvm import te from tvm import relay +from tvm import TVMError def test_bind_params(): @@ -34,5 +36,16 @@ def test_bind_params(): assert tvm.ir.structural_equal(zbinded, zexpected) +def test_bind_duplicated_params(): + a = relay.var('a', shape=(1,)) + aa = relay.var('a', shape=(1,)) + s = a + aa + func = relay.Function([a, aa], s) + + with pytest.raises(TVMError): + relay.build_module.bind_params_by_name(func, {'a': [1.0]}) + + if __name__ == "__main__": test_bind_params() + test_bind_duplicated_params() From 85bfdde31a546f302536eb667b4b2bb807eca5a8 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 23 Oct 2021 14:44:27 -0400 Subject: [PATCH 3/3] lint --- tests/python/relay/test_ir_bind.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index a48bc856d6df..0ab0122fa798 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -37,13 +37,13 @@ def test_bind_params(): def test_bind_duplicated_params(): - a = relay.var('a', shape=(1,)) - aa = relay.var('a', shape=(1,)) + a = relay.var("a", shape=(1,)) + aa = relay.var("a", shape=(1,)) s = a + aa func = relay.Function([a, aa], s) with pytest.raises(TVMError): - relay.build_module.bind_params_by_name(func, {'a': [1.0]}) + relay.build_module.bind_params_by_name(func, {"a": [1.0]}) if __name__ == "__main__":