Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/contrib/debugger/debug_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_debug_result(self, sort_by_time=True):
continue
name = node["name"]
shape = str(self._output_tensor_list[eid].shape)
time_us = round(time[0] * 1000000, 3)
time_us = round(time[0] * 1e6, 3)
time_percent = round(((time[0] / total_time) * 100), 3)
inputs = str(node["attrs"]["num_inputs"])
outputs = str(node["attrs"]["num_outputs"])
Expand All @@ -224,8 +224,8 @@ def get_debug_result(self, sort_by_time=True):
# Sort on the basis of execution time. Prints the most expensive ops in the start.
data = sorted(data, key=lambda x: x[2], reverse=True)
# Insert a row for total time at the end.
rounded_total_time = round(total_time * 1000000, 3)
data.append(["Total_time", "-", rounded_total_time, "-", "-", "-", "-", "-"])
rounded_total_time_us = round(total_time * 1e6, 3)
data.append(["Total_time", "-", rounded_total_time_us, "-", "-", "-", "-", "-"])

fmt = ""
for i, _ in enumerate(header):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _run_debug(self):
Time consumed for each execution will be set as debug output.

"""
self.debug_datum._time_list = [[float(t) * 1e-6] for t in self.run_individual(10, 1, 1)]
self.debug_datum._time_list = [[float(t)] for t in self.run_individual(10, 1, 1)]
for i, node in enumerate(self.debug_datum.get_graph_nodes()):
num_outputs = self.debug_datum.get_graph_node_output_num(node)
for j in range(num_outputs):
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Support infra of TVM."""
import ctypes
import tvm._ffi
from .runtime.module import Module
from . import get_global_func


def libinfo():
Expand All @@ -29,4 +32,26 @@ def libinfo():
return {k: v for k, v in GetLibInfo().items()} # pylint: disable=unnecessary-comprehension


class FrontendTestModule(Module):
"""A tvm.runtime.Module whose member functions are PackedFunc."""

def __init__(self, entry_name=None):
underlying_mod = get_global_func("testing.FrontendTestModule")()
handle = underlying_mod.handle

# Set handle to NULL to avoid cleanup in c++ runtime, transferring ownership.
# Both cython and ctypes FFI use c_void_p, so this is safe to assign here.
underlying_mod.handle = ctypes.c_void_p(0)

super(FrontendTestModule, self).__init__(handle)
if entry_name is not None:
self.entry_name = entry_name

def add_function(self, name, func):
self.get_function("__add_function")(name, func)

def __setitem__(self, key, value):
self.add_function(key, value)


tvm._ffi._init_api("support", __name__)
20 changes: 10 additions & 10 deletions src/runtime/graph/debug/graph_runtime_debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,19 @@ class GraphRuntimeDebug : public GraphRuntime {
// warmup run
GraphRuntime::Run();
std::string tkey = module_->type_key();
std::vector<double> time_per_op(op_execs_.size(), 0);
std::vector<double> time_sec_per_op(op_execs_.size(), 0);
if (tkey == "rpc") {
// RPC modules rely on remote timing which implements the logic from the else branch.
for (size_t index = 0; index < op_execs_.size(); ++index) {
time_per_op[index] += RunOpRPC(index, number, repeat, min_repeat_ms);
time_sec_per_op[index] += RunOpRPC(index, number, repeat, min_repeat_ms);
}
} else {
for (int i = 0; i < repeat; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock, std::chrono::nanoseconds>
tbegin, tend;
double duration_ms = 0.0;
do {
std::fill(time_per_op.begin(), time_per_op.end(), 0);
std::fill(time_sec_per_op.begin(), time_sec_per_op.end(), 0);
if (duration_ms > 0.0) {
number = static_cast<int>(std::max((min_repeat_ms / (duration_ms / number) + 1),
number * 1.618)); // 1.618 is chosen by random
Expand All @@ -80,7 +80,7 @@ class GraphRuntimeDebug : public GraphRuntime {
for (int k = 0; k < number; k++) {
for (size_t index = 0; index < op_execs_.size(); ++index) {
if (op_execs_[index]) {
time_per_op[index] += RunOpHost(index);
time_sec_per_op[index] += RunOpHost(index);
}
}
}
Expand All @@ -92,19 +92,19 @@ class GraphRuntimeDebug : public GraphRuntime {

LOG(INFO) << "Iteration: " << i;
int op = 0;
for (size_t index = 0; index < time_per_op.size(); index++) {
for (size_t index = 0; index < time_sec_per_op.size(); index++) {
if (op_execs_[index]) {
time_per_op[index] /= number;
LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " << time_per_op[index]
<< " us/iter";
time_sec_per_op[index] /= number;
LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": "
<< time_sec_per_op[index] * 1e6 << " us/iter";
}
}
}
}

std::ostringstream os;
for (size_t index = 0; index < time_per_op.size(); index++) {
os << time_per_op[index] << ",";
for (size_t index = 0; index < time_sec_per_op.size(); index++) {
os << time_sec_per_op[index] << ",";
}
return os.str();
}
Expand Down
42 changes: 42 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>
Expand Down Expand Up @@ -99,4 +100,45 @@ TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
// and get another value.
*ret = (obj.use_count() - 1);
});

class FrontendTestModuleNode : public runtime::ModuleNode {
public:
virtual const char* type_key() const { return "frontend_test"; }

static constexpr const char* kAddFunctionName = "__add_function";

virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

private:
std::unordered_map<std::string, PackedFunc> functions_;
};

constexpr const char* FrontendTestModuleNode::kAddFunctionName;

PackedFunc FrontendTestModuleNode::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
if (name == kAddFunctionName) {
return TypedPackedFunc<void(std::string, PackedFunc)>(
[this, sptr_to_self](std::string func_name, PackedFunc pf) {
CHECK_NE(func_name, kAddFunctionName)
<< "func_name: cannot be special function " << kAddFunctionName;
functions_[func_name] = pf;
});
}

auto it = functions_.find(name);
if (it == functions_.end()) {
return PackedFunc();
}

return it->second;
}

runtime::Module NewFrontendTestModule() {
auto n = make_object<FrontendTestModuleNode>();
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("testing.FrontendTestModule").set_body_typed(NewFrontendTestModule);

} // namespace tvm
54 changes: 49 additions & 5 deletions tests/python/unittest/test_runtime_graph_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@
# under the License.
import json
import os
import re
import sys
import time

import pytest

import tvm
import tvm.testing
from tvm import te
import numpy as np
from tvm import rpc
from tvm.contrib import utils
from tvm.contrib.debugger import debug_runtime as graph_runtime
from tvm.contrib.debugger import debug_runtime


@tvm.testing.requires_llvm
Expand Down Expand Up @@ -60,8 +66,16 @@ def test_graph_simple():

def check_verify():
mlib = tvm.build(s, [A, B], "llvm", name="myadd")

def myadd(*args):
to_return = mlib["myadd"](*args)
time.sleep(0.25)
return to_return

mlib_proxy = tvm.support.FrontendTestModule()
mlib_proxy["myadd"] = myadd
try:
mod = graph_runtime.create(graph, mlib, tvm.cpu(0))
mod = debug_runtime.create(graph, mlib_proxy, tvm.cpu(0))
except ValueError:
return

Expand Down Expand Up @@ -92,6 +106,36 @@ def check_verify():
# Verify the tensors are dumped
assert len(os.listdir(directory)) > 1

debug_lines = mod.debug_datum.get_debug_result().split("\n")

def split_debug_line(i):
to_return = re.split(r" [ ]*", debug_lines[i])
assert to_return[-1] == ""
to_return = to_return[:-1] # strip empty trailing part
return to_return

assert split_debug_line(0) == [
"Node Name",
"Ops",
"Time(us)",
"Time(%)",
"Shape",
"Inputs",
"Outputs",
]
myadd_lines = split_debug_line(2)
assert myadd_lines[0] == "add"
assert myadd_lines[1] == "myadd"
runtime_sec = float(myadd_lines[2]) / 1e6 # printed in us

# Ensure runtime is at least the sleep time and less than a unit prefix order of magnitude.
# Here we just care that the prefix is correct.
assert runtime_sec > 0.25 and runtime_sec < 0.25 * 1000

total_lines = split_debug_line(3)
assert total_lines[0] == "Total_time"
assert total_lines[2] == myadd_lines[2]

CHROME_TRACE_FILE_NAME = "_tvmdbg_execution_trace.json"
assert os.path.exists(os.path.join(directory, CHROME_TRACE_FILE_NAME))

Expand Down Expand Up @@ -127,9 +171,9 @@ def check_remote():
remote.upload(path_dso)
mlib = remote.load_module("dev_lib.so")
try:
mod = graph_runtime.create(graph, mlib, remote.cpu(0))
mod = debug_runtime.create(graph, mlib, remote.cpu(0))
except ValueError:
print("Skip because debug graph_runtime not enabled")
print("Skip because debug runtime not enabled")
return
a = np.random.uniform(size=(n,)).astype(A.dtype)
mod.run(x=tvm.nd.array(a, ctx))
Expand All @@ -142,4 +186,4 @@ def check_remote():


if __name__ == "__main__":
test_graph_simple()
sys.exit(pytest.main([__file__] + sys.argv[1:]))