diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index 9fe2a3a06fb7..716550ba045b 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -45,7 +45,8 @@ class PrimValueComputeInjector : public ExprMutator { auto param_vars = tir::UndefinedVars(node->value); tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), {node->value})); - tir::PrimFunc func(param_vars, body, PrimType(ret_dtype)); + tir::PrimFunc func(param_vars, body, PrimType(ret_dtype), {}, + DictAttrs({{tir::attr::kIsHostFunc, Bool(true)}})); func = tir::RenewDefs(func); auto callee = builder_->AddFunction(func, "compute_symbolic_expr"); diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py index 9fee35414d0d..5d9caf2d365c 100644 --- a/tests/python/relax/test_transform_compute_prim_value.py +++ b/tests/python/relax/test_transform_compute_prim_value.py @@ -44,6 +44,7 @@ def main(A: R.Tensor(["N"])): @T.prim_func(private=True) def compute_symbolic_expr(N: T.int64) -> T.bool: + T.func_attr({"tir.is_host_func": True}) T.ret(N % 16 == 0) @@ -73,6 +74,7 @@ def main(A: R.Tensor(["N"])): @T.prim_func(private=True) def compute_symbolic_expr(N: T.int64) -> T.bool: + T.func_attr({"tir.is_host_func": True}) T.ret(N % 16 == 0) @@ -97,6 +99,7 @@ def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) -> R.Prim(value="N*M"): @T.prim_func(private=True) def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64: + T.func_attr({"tir.is_host_func": True}) T.ret(N * M)