diff --git a/tests/python/relax/test_bind_params.py b/tests/python/relax/test_bind_params.py index a92e4fe8e510..189a44303d6c 100644 --- a/tests/python/relax/test_bind_params.py +++ b/tests/python/relax/test_bind_params.py @@ -111,21 +111,23 @@ def expected() -> R.Shape([16]): prim_value_dtype = tvm.testing.parameter("int64", "int32", "float32") -@pytest.mark.xfail(reason="Depends on relax.PrimValue holding a tir.PrimExpr, PR#15577") def test_bind_prim_value(prim_value_dtype): + N = tir.Var("N", prim_value_dtype) + value = tir.const(16, prim_value_dtype) + @R.function - def before(A: R.Prim(value="N", dtype=prim_value_dtype)): + def before(A: R.Prim(value=N)): R.func_attr({"global_symbol": "main"}) - B: R.Prim(value="N", dtype=prim_value_dtype) = A + B: R.Prim(value=N) = A return B @R.function - def expected() -> R.Prim(value=16, dtype=prim_value_dtype): + def expected() -> R.Prim(value=value): R.func_attr({"global_symbol": "main"}) - B = R.PrimValue(value=16, dtype=dtype) + B = R.prim_value(value) return B - after = before.bind_params({"A": relax.PrimValue(tir.const(16, prim_value_dtype))}) + after = before.bind_params({"A": relax.PrimValue(value)}) tvm.ir.assert_structural_equal(expected, after)