Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/runtime/profiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ class Report : public ObjectRef {
*/
explicit Report(Array<Map<String, ObjectRef>> calls,
Map<String, Map<String, ObjectRef>> device_metrics);

/*! Deserialize a Report from a JSON object. Needed for sending the report over RPC.
* \param json Serialized json report from `ReportNode::AsJSON`.
* \returns A Report.
*/
static Report FromJSON(String json);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Report, ObjectRef, ReportNode);
};

Expand Down
9 changes: 7 additions & 2 deletions python/tvm/contrib/debugger/debug_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm._ffi.base import string_types
from tvm.contrib import graph_executor
from . import debug_result
from ...runtime.profiling import Report

_DUMP_ROOT_PREFIX = "tvmdbg_"
_DUMP_PATH_PREFIX = "_tvmdbg_"
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(self, module, device, graph_json_str, dump_root):
self._execute_node = module["execute_node"]
self._get_node_output = module["get_node_output"]
self._profile = module["profile"]
self._profile_rpc = module["profile_rpc"]
graph_executor.GraphModule.__init__(self, module)
self._create_debug_env(graph_json_str, device)

def profile(self, collectors=None, **input_dict):
    """Run forward execution of the graph and collect per-operator profiling data.

    Parameters
    ----------
    collectors : Optional[Sequence[MetricCollector]]
        Extra metrics to collect. If profiling over RPC, collectors must be `None`.

    input_dict : dict of str to NDArray
        List of input values to be feed to

    Returns
    -------
    timing_results : str
        Per-operator and whole graph timing results in a table format.
    """
    if input_dict:
        self.set_input(**input_dict)

    if self.module.type_key == "rpc":
        # We cannot serialize MetricCollectors over RPC; the remote side runs
        # with an empty collector set and ships the Report back as JSON.
        assert collectors is None, "Profiling with collectors is not supported over RPC"
        return Report.from_json(self._profile_rpc())

    # Default `collectors` only AFTER the RPC check above: defaulting first
    # would turn None into [], making the RPC assertion always fail.
    collectors = [] if collectors is None else collectors
    return self._profile(collectors)

def exit(self):
Expand Down
19 changes: 16 additions & 3 deletions python/tvm/runtime/profiler_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
"""
import warnings
from tvm.runtime import _ffi_api
from tvm.rpc import base as rpc_base
from . import vm
from .profiling import Report


def enabled():
Expand All @@ -35,10 +37,18 @@ class VirtualMachineProfiler(vm.VirtualMachine):

def __init__(self, exe, device, memory_cfg=None):
    """Construct a profiling VM over `exe`, locally or on a remote RPC device.

    Parameters
    ----------
    exe : Executable
        The VM executable to run.
    device : Device
        Device on which to execute. May be a remote (RPC) device.
    memory_cfg : optional
        Memory configuration forwarded to the base VirtualMachine.
    """
    super(VirtualMachineProfiler, self).__init__(exe, device, memory_cfg)

    # Make sure the constructor of the VM module is on the proper device.
    # Remote (RPC) devices encode their device_type as the actual
    # device_type plus RPC_SESS_MASK, so any device_type >= RPC_SESS_MASK
    # lives on the remote side and the VM module must be constructed
    # through the RPC session rather than locally.
    if device.device_type >= rpc_base.RPC_SESS_MASK:
        self.module = device._rpc_sess.get_function("runtime._VirtualMachineDebug")(exe)
    else:
        self.module = _ffi_api._VirtualMachineDebug(exe.module)

    self._init = self.module["init"]
    self._invoke = self.module["invoke"]
    self._profile = self.module["profile"]
    self._profile_rpc = self.module["profile_rpc"]
    self._set_input = self.module["set_input"]
    self._setup_device(device, memory_cfg)

def profile(self, *args, func_name="main", collectors=None, **kwargs):
    """Profile a function call and return per-op timing results.

    Parameters
    ----------
    func_name : str
        The name of the function.

    collectors : Optional[Sequence[MetricCollector]]
        Extra metrics to collect. If profiling over RPC, collectors must be `None`.

    args : list[tvm.runtime.NDArray] or list[np.ndarray]
        The arguments to the function.

    kwargs : dict of str to tvm.runtime.NDArray or np.ndarray
        Named arguments to the function.

    Returns
    -------
    timing_results : str
        Overall and per-op timing results formatted in a table.
    """
    if args or kwargs:
        self.set_input(func_name, *args, **kwargs)
    if self.module.type_key == "rpc":
        # We cannot serialize MetricCollectors over RPC; the remote side runs
        # with an empty collector set and ships the Report back as JSON.
        assert collectors is None, "Profiling with collectors is not supported over RPC"
        return Report.from_json(self._profile_rpc(func_name))
    # Default `collectors` only AFTER the RPC check above: defaulting first
    # would turn None into [], making the RPC assertion always fail.
    collectors = [] if collectors is None else collectors
    return self._profile(func_name, collectors)
16 changes: 16 additions & 0 deletions python/tvm/runtime/profiling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,22 @@ def json(self):
"""
return _ffi_api.AsJSON(self)

@classmethod
def from_json(cls, s):
    """Deserialize a report from JSON.

    Parameters
    ----------
    s : str
        Report serialized via :py:meth:`json`.

    Returns
    -------
    report : Report
        The deserialized report.
    """
    return _ffi_api.FromJSON(s)


@_ffi.register_object("runtime.profiling.MetricCollector")
class MetricCollector(Object):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def _convert(arg, cargs):
dtype = "int32" if isinstance(arg, (_base.integer_types, bool)) else "float32"
value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0))
cargs.append(value)
elif isinstance(arg, str):
cargs.append(arg)
else:
raise TypeError("Unsupported type: %s" % (type(arg)))

Expand Down
17 changes: 16 additions & 1 deletion src/runtime/graph_executor/debug/graph_executor_debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,23 @@ PackedFunc GraphExecutorDebug::GetFunction(const std::string& name,
} else if (name == "profile") {
return TypedPackedFunc<profiling::Report(Array<profiling::MetricCollector>)>(
[sptr_to_self, this](Array<profiling::MetricCollector> collectors) {
return this->Profile(collectors);
// We cannot send Arrays over rpc, so in order to support profiling
// on remotes, we accept a nullptr for collectors.
if (collectors.defined()) {
return this->Profile(collectors);
} else {
return this->Profile({});
}
});
} else if (name == "profile_rpc") {
// We cannot return a Report over RPC because TVM RPC mechanism only
// supports a subset of Object classes. Instead we serialize it on the
// remote (here) and deserialize it on the other end.
return TypedPackedFunc<std::string()>([sptr_to_self, this]() {
PackedFunc profile = GetFunction("profile", sptr_to_self);
profiling::Report report = profile(Array<profiling::MetricCollector>());
return report->AsJSON();
});
} else {
return GraphExecutor::GetFunction(name, sptr_to_self);
}
Expand Down
72 changes: 71 additions & 1 deletion src/runtime/profiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief Runtime profiling including timers.
*/

#include <dmlc/json.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/packed_func.h>
Expand Down Expand Up @@ -231,7 +232,9 @@ String ReportNode::AsCSV() const {
namespace {
void print_metric(std::ostream& os, ObjectRef o) {
if (o.as<StringObj>()) {
os << "\"" << Downcast<String>(o) << "\"";
os << "{\"string\":"
<< "\"" << Downcast<String>(o) << "\""
<< "}";
} else if (const CountNode* n = o.as<CountNode>()) {
os << "{\"count\":" << std::to_string(n->value) << "}";
} else if (const DurationNode* n = o.as<DurationNode>()) {
Expand Down Expand Up @@ -540,6 +543,72 @@ Report::Report(Array<Map<String, ObjectRef>> calls,
data_ = std::move(node);
}

/*! \brief Parse one JSON object of metrics: {"name": {"type": value}, ...}.
 *
 * Each metric value is a single-key object whose key identifies the metric
 * type: "microseconds", "percent", "count", or "string".
 * \param reader JSON reader positioned at the start of the metrics object.
 * \returns A map from metric name to the parsed metric object.
 */
Map<String, ObjectRef> parse_metrics(dmlc::JSONReader* reader) {
  reader->BeginObject();
  std::string metric_name, metric_value_name;
  Map<String, ObjectRef> metrics;
  while (reader->NextObjectItem(&metric_name)) {
    ObjectRef o;
    reader->BeginObject();
    reader->NextObjectItem(&metric_value_name);
    if (metric_value_name == "microseconds") {
      double microseconds;
      reader->Read(&microseconds);
      o = ObjectRef(make_object<DurationNode>(microseconds));
    } else if (metric_value_name == "percent") {
      double percent;
      reader->Read(&percent);
      o = ObjectRef(make_object<PercentNode>(percent));
    } else if (metric_value_name == "count") {
      int64_t count;
      reader->Read(&count);
      o = ObjectRef(make_object<CountNode>(count));
    } else if (metric_value_name == "string") {
      std::string s;
      reader->Read(&s);
      o = String(s);
    } else {
      // "string" was missing from this list even though the branch above
      // accepts it; keep the message in sync with the supported types.
      LOG(FATAL) << "Cannot parse metric of type " << metric_value_name
                 << " valid types are microseconds, percent, count, string.";
    }
    metrics.Set(metric_name, o);
    // Necessary to make sure that the parser hits the end of the object.
    ICHECK(!reader->NextObjectItem(&metric_value_name));
    // EndObject does not exist, leaving this here for clarity
    // reader.EndObject();
  }
  // reader.EndObject();
  return metrics;
}

/*! \brief Deserialize a Report from the JSON produced by ReportNode::AsJSON.
 *
 * Expects a top-level object with a "calls" array of metric objects and a
 * "device_metrics" object mapping device name to a metric object.
 * \param json Serialized report string.
 * \returns The reconstructed Report.
 */
Report Report::FromJSON(String json) {
  std::stringstream input(json.operator std::string());
  dmlc::JSONReader reader(&input);
  std::string key;
  Array<Map<String, ObjectRef>> calls;
  Map<String, Map<String, ObjectRef>> device_metrics;

  reader.BeginObject();
  while (reader.NextObjectItem(&key)) {
    if (key == "calls") {
      reader.BeginArray();
      while (reader.NextArrayItem()) {
        calls.push_back(parse_metrics(&reader));
      }
      // reader.EndArray();
    } else if (key == "device_metrics") {
      reader.BeginObject();
      std::string device_name;
      while (reader.NextObjectItem(&device_name)) {
        device_metrics.Set(device_name, parse_metrics(&reader));
      }
      // reader.EndObject();
    }
  }

  return Report(calls, device_metrics);
}

TVM_REGISTER_OBJECT_TYPE(DurationNode);
TVM_REGISTER_OBJECT_TYPE(PercentNode);
TVM_REGISTER_OBJECT_TYPE(CountNode);
Expand All @@ -551,6 +620,7 @@ TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { ret
TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) {
return n->AsJSON();
});
TVM_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON);
TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) {
return DeviceWrapper(dev);
});
Expand Down
19 changes: 17 additions & 2 deletions src/runtime/vm/profiler/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name,
}
}

std::vector<profiling::MetricCollector> cs(collectors.begin(), collectors.end());
prof_ = profiling::Profiler(devices, cs);
// We cannot send Arrays over rpc, so in order to support profiling
// on remotes, we accept a nullptr for collectors.
if (collectors.defined()) {
std::vector<profiling::MetricCollector> cs(collectors.begin(), collectors.end());
prof_ = profiling::Profiler(devices, cs);
} else {
prof_ = profiling::Profiler(devices, {});
}

auto invoke = VirtualMachine::GetFunction("invoke", sptr_to_self);
// warmup
Expand All @@ -68,6 +74,15 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name,
prof_ = dmlc::optional<profiling::Profiler>(); // releases hardware counters
return report;
});
} else if (name == "profile_rpc") {
// We cannot return a Report over RPC because TVM RPC mechanism only
// supports a subset of Object classes. Instead we serialize it on the
// remote (here) and deserialize it on the other end.
return TypedPackedFunc<std::string(std::string)>([sptr_to_self, this](std::string arg_name) {
PackedFunc profile = GetFunction("profile", sptr_to_self);
profiling::Report report = profile(arg_name, Array<profiling::MetricCollector>());
return report->AsJSON();
});
} else {
return VirtualMachine::GetFunction(name, sptr_to_self);
}
Expand Down
56 changes: 54 additions & 2 deletions tests/python/unittest/test_runtime_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from tvm import relay
from tvm.relay.testing import mlp
from tvm.contrib.debugger import debug_executor
from tvm import rpc
from tvm.contrib import utils
from tvm.runtime.profiling import Report


def read_csv(report):
Expand Down Expand Up @@ -102,7 +105,6 @@ def test_papi(target, dev):
func_name="main",
collectors=[tvm.runtime.profiling.PAPIMetricCollector({dev: [metric]})],
)
print(report)
assert metric in str(report)

csv = read_csv(report)
Expand All @@ -126,10 +128,60 @@ def test_json():
assert "microseconds" in parsed["calls"][0]["Duration (us)"]
assert len(parsed["calls"]) > 0
for call in parsed["calls"]:
assert isinstance(call["Name"], str)
assert isinstance(call["Name"]["string"], str)
assert isinstance(call["Count"]["count"], int)
assert isinstance(call["Duration (us)"]["microseconds"], float)


@tvm.testing.requires_llvm
def test_rpc_vm():
    # Profile a VM executable that runs on the far side of a local RPC session.
    rpc_server = rpc.Server(key="profiling")
    sess = rpc.connect("127.0.0.1", rpc_server.port, key="profiling")

    mod, params = mlp.get_workload(1)
    compiled = relay.vm.compile(mod, "llvm", params=params)

    # Ship the compiled library to the "remote" and load it there.
    tmp_dir = utils.tempdir()
    lib_path = tmp_dir.relpath("lib.tar")
    compiled.mod.export_library(lib_path)
    sess.upload(lib_path)
    remote_mod = sess.load_module("lib.tar")

    profiler = profiler_vm.VirtualMachineProfiler(remote_mod, sess.cpu())
    prof_report = profiler.profile(
        tvm.nd.array(np.ones((1, 1, 28, 28), dtype="float32"), device=sess.cpu())
    )
    assert len(prof_report.calls) > 0


@tvm.testing.requires_llvm
def test_rpc_graph():
    """Profile a debug graph executor over a local RPC session.

    Decorated with ``requires_llvm`` (like ``test_rpc_vm``) because the
    workload is compiled for the "llvm" target.
    """
    server = rpc.Server(key="profiling")
    remote = rpc.connect("127.0.0.1", server.port, key="profiling")

    mod, params = mlp.get_workload(1)
    exe = relay.build(mod, "llvm", params=params)
    temp = utils.tempdir()
    path = temp.relpath("lib.tar")
    exe.export_library(path)
    remote.upload(path)
    rexec = remote.load_module("lib.tar")

    gr = debug_executor.create(exe.get_graph_json(), rexec, remote.cpu())

    data = np.random.rand(1, 1, 28, 28).astype("float32")
    report = gr.profile(data=data)
    assert len(report.calls) > 0


def test_report_serialization():
    # Round-trip a profiling Report through its JSON form and check that the
    # restored report renders identically to the original.
    mod, params = mlp.get_workload(1)
    exe = relay.vm.compile(mod, "llvm", params=params)
    profiler = profiler_vm.VirtualMachineProfiler(exe, tvm.cpu())

    inp = np.random.rand(1, 1, 28, 28).astype("float32")
    original = profiler.profile(inp, func_name="main")

    restored = Report.from_json(original.json())
    # equality on reports compares pointers, so we compare the printed results instead.
    assert str(original) == str(restored)


if __name__ == "__main__":
test_papi("llvm", tvm.cpu())