Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,11 @@ class ObjectPtr {
ObjectPtr(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief nullptr check
* \return result of comparison of internal pointer with nullptr.
*/
explicit operator bool() const { return get() != nullptr; }
/*! \brief reset the content of ptr to be nullptr */
void reset() {
if (data_ != nullptr) {
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class VirtualMachine : public runtime::ModuleNode {
* \brief load the executable for the virtual machine.
* \param exec The executable.
*/
virtual void LoadExecutable(Executable* exec);
virtual void LoadExecutable(const ObjectPtr<Executable>& exec);

protected:
/*! \brief Push a call frame on to the call stack. */
Expand Down Expand Up @@ -300,7 +300,7 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The special return register. */
ObjectRef return_register_;
/*! \brief The executable the VM will operate on. */
Executable* exec_;
ObjectPtr<Executable> exec_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you get Executable an ObjectRef wrapper like the rest of the code base you can then use the down casting machinery built on-top of ObjectRef. It might make working with this type easier.

Copy link
Contributor Author

@vvchernov vvchernov Jan 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have gotten idea, I will try

/*! \brief The function name to inputs mapping. */
std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
/*!
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
} else if (name == "vm_load_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(this);
ICHECK(sptr_to_self.get() == this);
vm->LoadExecutable(GetObjectPtr<Executable>(this));
*rv = Module(vm);
});
} else if (name == "move_late_bound_consts") {
Expand Down
7 changes: 2 additions & 5 deletions src/runtime/vm/profiler/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,8 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name,
}
}

void VirtualMachineDebug::LoadExecutable(Executable* exec) {
void VirtualMachineDebug::LoadExecutable(const ObjectPtr<Executable>& exec) {
VirtualMachine::LoadExecutable(exec);
ICHECK(exec_);
for (auto kv : exec_->primitive_map) {
packed_index_map_[kv.second] = kv.first;
}
Expand Down Expand Up @@ -204,15 +203,13 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun

runtime::Module CreateVirtualMachineDebug(Executable* exec) {
auto vm = make_object<VirtualMachineDebug>();
vm->LoadExecutable(exec);
vm->LoadExecutable(GetObjectPtr<Executable>(exec));
return runtime::Module(vm);
}

TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
auto* exec = dynamic_cast<Executable*>(mod.operator->());
ICHECK(exec) << "Virtual machine has not been defined yet."
<< "\n";
*rv = CreateVirtualMachineDebug(exec);
});

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/vm/profiler/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class VirtualMachineDebug : public VirtualMachine {

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

void LoadExecutable(Executable* exec) final;
void LoadExecutable(const ObjectPtr<Executable>& exec) final;

~VirtualMachineDebug() {}

Expand Down
7 changes: 3 additions & 4 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,15 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
}
}

void VirtualMachine::LoadExecutable(Executable* exec) {
void VirtualMachine::LoadExecutable(const ObjectPtr<Executable>& exec) {
ICHECK(exec) << "The executable is not created yet.";
ICHECK(exec->late_bound_constant_names.empty())
<< "Need to load late-bound-constants before creating VM";
exec_ = exec;

runtime::Module lib = exec_->GetLib();

ICHECK(exec->primitive_map.empty() || lib.operator->())
ICHECK(exec_->primitive_map.empty() || lib.operator->())
<< "If the executable has declared primitive functions, the "
<< "generated kernel library must non-be null.";

Expand Down Expand Up @@ -769,14 +769,13 @@ void VirtualMachine::RunLoop() {

runtime::Module CreateVirtualMachine(Executable* exec) {
auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(exec);
vm->LoadExecutable(GetObjectPtr<Executable>(exec));
return runtime::Module(vm);
}

TVM_REGISTER_GLOBAL("runtime._VirtualMachine").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
auto* exec = dynamic_cast<Executable*>(mod.operator->());
ICHECK(exec) << "The virtual machine executable has not been defined yet.";
*rv = CreateVirtualMachine(exec);
});

Expand Down