From acc6b0b2496827036bacc9ed59a03bb6478ef59c Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Fri, 28 Feb 2025 11:09:43 +0100 Subject: [PATCH 1/2] updated the assert in BindParams to allow tvm.relax.Constant in the input dictionary --- python/tvm/relax/transform/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 603211b59ebc..ff34477dfdb3 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -659,8 +659,8 @@ def BindParams( if isinstance(v, np.ndarray): v = tvm.nd.array(v) assert isinstance( - v, tvm.runtime.NDArray - ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}" + v, (tvm.runtime.NDArray, tvm.relax.Constant) + ), f"param values are expected to be TVM.NDArray, numpy.ndarray or tvm.relax.Constant, but got {type(v)}" tvm_params[k] = v return _ffi_api.BindParams(func_name, tvm_params) # type: ignore From 87c93acc745eb3680569bf36f1fd351ab1f279bc Mon Sep 17 00:00:00 2001 From: Patrik Persson Date: Mon, 3 Mar 2025 10:19:36 +0100 Subject: [PATCH 2/2] fixed linting --- python/tvm/relax/transform/transform.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index ff34477dfdb3..a72439079ef7 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -658,9 +658,10 @@ def BindParams( for k, v in params.items(): if isinstance(v, np.ndarray): v = tvm.nd.array(v) - assert isinstance( - v, (tvm.runtime.NDArray, tvm.relax.Constant) - ), f"param values are expected to be TVM.NDArray, numpy.ndarray or tvm.relax.Constant, but got {type(v)}" + assert isinstance(v, (tvm.runtime.NDArray, tvm.relax.Constant)), ( + f"param values are expected to be TVM.NDArray," + f"numpy.ndarray or tvm.relax.Constant, but got {type(v)}" + ) tvm_params[k] = v return _ffi_api.BindParams(func_name, tvm_params) # type: ignore