diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index d5c017502426..1b3ad571dbaa 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -513,8 +513,11 @@ class TVMArgValue : public TVMPODValue_ { } } operator tvm::runtime::String() const { - // directly use the std::string constructor for now. - return tvm::runtime::String(operator std::string()); + if (IsObjectRef()) { + return AsObjectRef(); + } else { + return tvm::runtime::String(operator std::string()); + } } operator DLDataType() const { if (type_code_ == kTVMStr) { @@ -605,8 +608,11 @@ class TVMRetValue : public TVMPODValue_ { return *ptr(); } operator tvm::runtime::String() const { - // directly use the std::string constructor for now. - return tvm::runtime::String(operator std::string()); + if (IsObjectRef()) { + return AsObjectRef(); + } else { + return tvm::runtime::String(operator std::string()); + } } operator DLDataType() const { if (type_code_ == kTVMStr) { diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 4a815ffd5d7d..d0313c60d984 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -95,6 +95,12 @@ TEST(PackedFunc, str) { CHECK(y == "hello"); *rv = x; })("hello"); + + PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK(args.num_args == 1); + runtime::String s = args[0]; + CHECK(s == "hello"); + })(runtime::String("hello")); }