diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 090bcf01b5a5..f4b272979bb6 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -531,6 +531,11 @@ class VMShapeLowerMutator
     // the shape_func to indicate that this is a host function
     // This could require us to attach target to the relax function here.
    tir::PrimFunc shape_func(params, body, ret_type, buffer_map);
+    if (shape_func->attrs.GetAttr(tvm::attr::kTarget) == nullptr) {
+      // kTarget and kIsHostFunc are mutually exclusive
+      shape_func =
+          WithAttr(std::move(shape_func), tvm::tir::attr::kIsHostFunc, Integer(1));
+    }
     GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func");
     builder_->Emit(Call(shape_func_var, {shape_heap_}), "_");
     return to_compute.size();
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 5cd104dd013f..9c11b352c831 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -178,6 +178,7 @@ class Expected:
         @T.prim_func
         def shape_func(H: T.Buffer(T.int64(4), "int64")):
             # generated compute function
+            T.func_attr({"tir.is_host_func": 1})
             H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1)
 
         @R.function
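
For context, the C++ change above attaches `tir.is_host_func` to the generated shape function only when no `kTarget` attribute is present, since the two attributes are mutually exclusive. Below is a minimal Python sketch (not part of the PR) of that same guard, assuming a TVM build with TVMScript available; the constant buffer indices are simplified placeholders for the test's `sindex` mapping:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def shape_func(H: T.Buffer(T.int64(4), "int64")):
    # Stand-in for the generated compute function; no target attr is set yet.
    H[T.int64(1)] = H[T.int64(0)] + T.int64(1)


# Mirror the C++ guard: only attach tir.is_host_func when no target is
# attached, because kTarget and kIsHostFunc are mutually exclusive.
if shape_func.attrs is None or "target" not in shape_func.attrs:
    shape_func = shape_func.with_attr("tir.is_host_func", 1)

# The attribute is now visible on the function, matching what the updated
# test expects from T.func_attr({"tir.is_host_func": 1}).
assert int(shape_func.attrs["tir.is_host_func"]) == 1
```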