diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0ed61177e65a..7046afcaf482 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -438,6 +438,11 @@ class ObjectPtr { ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) return *this; } + /*! + * \brief nullptr check + * \return result of comparison of internal pointer with nullptr. + */ + explicit operator bool() const { return get() != nullptr; } /*! \brief reset the content of ptr to be nullptr */ void reset() { if (data_ != nullptr) { diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 67c21a1b479f..d7311951b702 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -174,7 +174,7 @@ class VirtualMachine : public runtime::ModuleNode { * \brief load the executable for the virtual machine. * \param exec The executable. */ - virtual void LoadExecutable(Executable* exec); + virtual void LoadExecutable(const ObjectPtr& exec); protected: /*! \brief Push a call frame on to the call stack. */ @@ -300,7 +300,7 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The special return register. */ ObjectRef return_register_; /*! \brief The executable the VM will operate on. */ - Executable* exec_; + ObjectPtr exec_; /*! \brief The function name to inputs mapping. */ std::unordered_map> inputs_; /*! diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index e2fe867630b0..0e2246835995 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -86,7 +86,8 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr(); - vm->LoadExecutable(this); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr(this)); *rv = Module(vm); }); } else if (name == "move_late_bound_consts") { diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 67344df7dbe6..bcbd9011a1df 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -90,9 +90,8 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, } } -void VirtualMachineDebug::LoadExecutable(Executable* exec) { +void VirtualMachineDebug::LoadExecutable(const ObjectPtr& exec) { VirtualMachine::LoadExecutable(exec); - ICHECK(exec_); for (auto kv : exec_->primitive_map) { packed_index_map_[kv.second] = kv.first; } @@ -204,15 +203,13 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun runtime::Module CreateVirtualMachineDebug(Executable* exec) { auto vm = make_object(); - vm->LoadExecutable(exec); + vm->LoadExecutable(GetObjectPtr(exec)); return runtime::Module(vm); } TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; auto* exec = dynamic_cast(mod.operator->()); - ICHECK(exec) << "Virtual machine has not been defined yet." - << "\n"; *rv = CreateVirtualMachineDebug(exec); }); diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 4a09b51fb86e..0c9e94c0ddd6 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -44,7 +44,7 @@ class VirtualMachineDebug : public VirtualMachine { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void LoadExecutable(Executable* exec) final; + void LoadExecutable(const ObjectPtr& exec) final; ~VirtualMachineDebug() {} diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 7a83c9acb906..b1b5068ee19e 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -374,7 +374,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In } } -void VirtualMachine::LoadExecutable(Executable* exec) { +void VirtualMachine::LoadExecutable(const ObjectPtr& exec) { ICHECK(exec) << "The executable is not created yet."; ICHECK(exec->late_bound_constant_names.empty()) << "Need to load late-bound-constants before creating VM"; @@ -382,7 +382,7 @@ void VirtualMachine::LoadExecutable(Executable* exec) { runtime::Module lib = exec_->GetLib(); - ICHECK(exec->primitive_map.empty() || lib.operator->()) + ICHECK(exec_->primitive_map.empty() || lib.operator->()) << "If the executable has declared primitive functions, the " << "generated kernel library must non-be null."; @@ -769,14 +769,13 @@ void VirtualMachine::RunLoop() { runtime::Module CreateVirtualMachine(Executable* exec) { auto vm = make_object(); - vm->LoadExecutable(exec); + vm->LoadExecutable(GetObjectPtr(exec)); return runtime::Module(vm); } TVM_REGISTER_GLOBAL("runtime._VirtualMachine").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; auto* exec = dynamic_cast(mod.operator->()); - ICHECK(exec) << "The virtual machine executable has not been defined yet."; *rv = CreateVirtualMachine(exec); });