diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 6d59b858927c..3f72ad5d0475 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -327,7 +327,7 @@ inline relay::Function BindParamsByName( for (auto arg : func->params) { const auto& name = arg->name_hint(); if (name_dict.count(name)) { - repeat_var.insert(arg); + repeat_var.insert(name_dict[name]); } else { name_dict[name] = arg; } diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index b179096a0528..0ab0122fa798 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. """ test bind function.""" +import pytest import tvm from tvm import te from tvm import relay +from tvm import TVMError def test_bind_params(): @@ -34,5 +36,16 @@ def test_bind_params(): assert tvm.ir.structural_equal(zbinded, zexpected) +def test_bind_duplicated_params(): + a = relay.var("a", shape=(1,)) + aa = relay.var("a", shape=(1,)) + s = a + aa + func = relay.Function([a, aa], s) + + with pytest.raises(TVMError): + relay.build_module.bind_params_by_name(func, {"a": [1.0]}) + + if __name__ == "__main__": test_bind_params() + test_bind_duplicated_params()