diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index eea195f64a6d..7ee140622bfc 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -251,6 +251,12 @@ class Report : public ObjectRef { */ explicit Report(Array> calls, Map> device_metrics); + + /*! Deserialize a Report from a JSON object. Needed for sending the report over RPC. + * \param json Serialized json report from `ReportNode::AsJSON`. + * \returns A Report. + */ + static Report FromJSON(String json); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Report, ObjectRef, ReportNode); }; diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py index fc3b245d88ad..6ffae08621e0 100644 --- a/python/tvm/contrib/debugger/debug_executor.py +++ b/python/tvm/contrib/debugger/debug_executor.py @@ -25,6 +25,7 @@ from tvm._ffi.base import string_types from tvm.contrib import graph_executor from . import debug_result +from ...runtime.profiling import Report _DUMP_ROOT_PREFIX = "tvmdbg_" _DUMP_PATH_PREFIX = "_tvmdbg_" @@ -102,6 +103,7 @@ def __init__(self, module, device, graph_json_str, dump_root): self._execute_node = module["execute_node"] self._get_node_output = module["get_node_output"] self._profile = module["profile"] + self._profile_rpc = module["profile_rpc"] graph_executor.GraphModule.__init__(self, module) self._create_debug_env(graph_json_str, device) @@ -274,7 +276,7 @@ def profile(self, collectors=None, **input_dict): Parameters ---------- collectors : Optional[Sequence[MetricCollector]] - Extra metrics to collect. + Extra metrics to collect. If profiling over RPC, collectors must be `None`. input_dict : dict of str to NDArray List of input values to be feed to @@ -284,10 +286,13 @@ def profile(self, collectors=None, **input_dict): timing_results : str Per-operator and whole graph timing results in a table format. 
""" - collectors = [] if collectors is None else collectors if input_dict: self.set_input(**input_dict) + if self.module.type_key == "rpc": + # We cannot serialize MetricCollectors over RPC + assert collectors is None, "Profiling with collectors is not supported over RPC" + return Report.from_json(self._profile_rpc()) return self._profile(collectors) def exit(self): diff --git a/python/tvm/runtime/profiler_vm.py b/python/tvm/runtime/profiler_vm.py index b3043d8b8760..4f625c0c67f1 100644 --- a/python/tvm/runtime/profiler_vm.py +++ b/python/tvm/runtime/profiler_vm.py @@ -22,7 +22,9 @@ """ import warnings from tvm.runtime import _ffi_api +from tvm.rpc import base as rpc_base from . import vm +from .profiling import Report def enabled(): @@ -35,10 +37,18 @@ class VirtualMachineProfiler(vm.VirtualMachine): def __init__(self, exe, device, memory_cfg=None): super(VirtualMachineProfiler, self).__init__(exe, device, memory_cfg) - self.module = _ffi_api._VirtualMachineDebug(exe.module) + + # Make sure the constructor of the VM module is on the proper device + # Remote devices have device_type of their actual device_type + RPC_SESS_MASK + if device.device_type >= rpc_base.RPC_SESS_MASK: + self.module = device._rpc_sess.get_function("runtime._VirtualMachineDebug")(exe) + else: + self.module = _ffi_api._VirtualMachineDebug(exe.module) + self._init = self.module["init"] self._invoke = self.module["invoke"] self._profile = self.module["profile"] + self._profile_rpc = self.module["profile_rpc"] self._set_input = self.module["set_input"] self._setup_device(device, memory_cfg) @@ -59,7 +69,7 @@ def profile(self, *args, func_name="main", collectors=None, **kwargs): The name of the function. collectors : Optional[Sequence[MetricCollector]] - Extra metrics to collect. + Extra metrics to collect. If profiling over RPC, collectors must be `None`. args : list[tvm.runtime.NDArray] or list[np.ndarray] The arguments to the function. 
@@ -72,7 +82,10 @@ def profile(self, *args, func_name="main", collectors=None, **kwargs): timing_results : str Overall and per-op timing results formatted in a table. """ - collectors = [] if collectors is None else collectors if args or kwargs: self.set_input(func_name, *args, **kwargs) + if self.module.type_key == "rpc": + # We cannot serialize MetricCollectors over RPC + assert collectors is None, "Profiling with collectors is not supported over RPC" + return Report.from_json(self._profile_rpc(func_name)) return self._profile(func_name, collectors) diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 881691609398..b91fe727698b 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -104,6 +104,22 @@ def json(self): """ return _ffi_api.AsJSON(self) + + @classmethod + def from_json(cls, s): + """Deserialize a report from JSON. + + Parameters + ---------- + s : str + Report serialized via :py:meth:`json`. + + Returns + ------- + report : Report + The deserialized report. 
+ """ + return _ffi_api.FromJSON(s) + @_ffi.register_object("runtime.profiling.MetricCollector") class MetricCollector(Object): diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 6416ad7814e1..2be3f3ec1a78 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -48,6 +48,8 @@ def _convert(arg, cargs): dtype = "int32" if isinstance(arg, (_base.integer_types, bool)) else "float32" value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) cargs.append(value) + elif isinstance(arg, str): + cargs.append(arg) else: raise TypeError("Unsupported type: %s" % (type(arg))) diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc index 2fa73971d000..12a739722a5c 100644 --- a/src/runtime/graph_executor/debug/graph_executor_debug.cc +++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc @@ -365,8 +365,23 @@ PackedFunc GraphExecutorDebug::GetFunction(const std::string& name, } else if (name == "profile") { return TypedPackedFunc)>( [sptr_to_self, this](Array collectors) { - return this->Profile(collectors); + // We cannot send Arrays over rpc, so in order to support profiling + // on remotes, we accept a nullptr for collectors. + if (collectors.defined()) { + return this->Profile(collectors); + } else { + return this->Profile({}); + } }); + } else if (name == "profile_rpc") { + // We cannot return a Report over RPC because TVM RPC mechanism only + // supports a subset of Object classes. Instead we serialize it on the + // remote (here) and deserialize it on the other end. 
+ return TypedPackedFunc([sptr_to_self, this]() { + PackedFunc profile = GetFunction("profile", sptr_to_self); + profiling::Report report = profile(Array()); + return report->AsJSON(); + }); } else { return GraphExecutor::GetFunction(name, sptr_to_self); } diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 596b6ace8831..bd59be87f7d9 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -22,6 +22,7 @@ * \brief Runtime profiling including timers. */ +#include #include #include #include @@ -231,7 +232,9 @@ String ReportNode::AsCSV() const { namespace { void print_metric(std::ostream& os, ObjectRef o) { if (o.as()) { - os << "\"" << Downcast(o) << "\""; + os << "{\"string\":" + << "\"" << Downcast(o) << "\"" + << "}"; } else if (const CountNode* n = o.as()) { os << "{\"count\":" << std::to_string(n->value) << "}"; } else if (const DurationNode* n = o.as()) { @@ -540,6 +543,72 @@ Report::Report(Array> calls, data_ = std::move(node); } +Map parse_metrics(dmlc::JSONReader* reader) { + reader->BeginObject(); + std::string metric_name, metric_value_name; + Map metrics; + while (reader->NextObjectItem(&metric_name)) { + ObjectRef o; + reader->BeginObject(); + reader->NextObjectItem(&metric_value_name); + if (metric_value_name == "microseconds") { + double microseconds; + reader->Read(µseconds); + o = ObjectRef(make_object(microseconds)); + } else if (metric_value_name == "percent") { + double percent; + reader->Read(&percent); + o = ObjectRef(make_object(percent)); + } else if (metric_value_name == "count") { + int64_t count; + reader->Read(&count); + o = ObjectRef(make_object(count)); + } else if (metric_value_name == "string") { + std::string s; + reader->Read(&s); + o = String(s); + } else { + LOG(FATAL) << "Cannot parse metric of type " << metric_value_name + << " valid types are microseconds, percent, count."; + } + metrics.Set(metric_name, o); + // Necessary to make sure that the parser hits the end of the object. 
+ ICHECK(!reader->NextObjectItem(&metric_value_name)); + // EndObject does not exist, leaving this here for clarity + // reader.EndObject(); + } + // reader.EndObject(); + return metrics; +} + +Report Report::FromJSON(String json) { + std::stringstream input(json.operator std::string()); + dmlc::JSONReader reader(&input); + std::string key; + Array> calls; + Map> device_metrics; + + reader.BeginObject(); + while (reader.NextObjectItem(&key)) { + if (key == "calls") { + reader.BeginArray(); + while (reader.NextArrayItem()) { + calls.push_back(parse_metrics(&reader)); + } + // reader.EndArray(); + } else if (key == "device_metrics") { + reader.BeginObject(); + std::string device_name; + while (reader.NextObjectItem(&device_name)) { + device_metrics.Set(device_name, parse_metrics(&reader)); + } + // reader.EndObject(); + } + } + + return Report(calls, device_metrics); +} + TVM_REGISTER_OBJECT_TYPE(DurationNode); TVM_REGISTER_OBJECT_TYPE(PercentNode); TVM_REGISTER_OBJECT_TYPE(CountNode); @@ -551,6 +620,7 @@ TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { ret TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { return n->AsJSON(); }); +TVM_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON); TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { return DeviceWrapper(dev); }); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 6d893114d623..d6575c35d10d 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -52,8 +52,14 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, } } - std::vector cs(collectors.begin(), collectors.end()); - prof_ = profiling::Profiler(devices, cs); + // We cannot send Arrays over rpc, so in order to support profiling + // on remotes, we accept a nullptr for collectors. 
+ if (collectors.defined()) { + std::vector cs(collectors.begin(), collectors.end()); + prof_ = profiling::Profiler(devices, cs); + } else { + prof_ = profiling::Profiler(devices, {}); + } auto invoke = VirtualMachine::GetFunction("invoke", sptr_to_self); // warmup @@ -68,6 +74,15 @@ PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, prof_ = dmlc::optional(); // releases hardware counters return report; }); + } else if (name == "profile_rpc") { + // We cannot return a Report over RPC because TVM RPC mechanism only + // supports a subset of Object classes. Instead we serialize it on the + // remote (here) and deserialize it on the other end. + return TypedPackedFunc([sptr_to_self, this](std::string arg_name) { + PackedFunc profile = GetFunction("profile", sptr_to_self); + profiling::Report report = profile(arg_name, Array()); + return report->AsJSON(); + }); } else { return VirtualMachine::GetFunction(name, sptr_to_self); } diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index 8306f2f67fa1..ca6cb0181489 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -26,6 +26,9 @@ from tvm import relay from tvm.relay.testing import mlp from tvm.contrib.debugger import debug_executor +from tvm import rpc +from tvm.contrib import utils +from tvm.runtime.profiling import Report def read_csv(report): @@ -102,7 +105,6 @@ def test_papi(target, dev): func_name="main", collectors=[tvm.runtime.profiling.PAPIMetricCollector({dev: [metric]})], ) - print(report) assert metric in str(report) csv = read_csv(report) @@ -126,10 +128,60 @@ def test_json(): assert "microseconds" in parsed["calls"][0]["Duration (us)"] assert len(parsed["calls"]) > 0 for call in parsed["calls"]: - assert isinstance(call["Name"], str) + assert isinstance(call["Name"]["string"], str) assert isinstance(call["Count"]["count"], int) assert isinstance(call["Duration 
(us)"]["microseconds"], float) +@tvm.testing.requires_llvm +def test_rpc_vm(): + server = rpc.Server(key="profiling") + remote = rpc.connect("127.0.0.1", server.port, key="profiling") + + mod, params = mlp.get_workload(1) + exe = relay.vm.compile(mod, "llvm", params=params) + temp = utils.tempdir() + path = temp.relpath("lib.tar") + exe.mod.export_library(path) + remote.upload(path) + rexec = remote.load_module("lib.tar") + vm = profiler_vm.VirtualMachineProfiler(rexec, remote.cpu()) + report = vm.profile(tvm.nd.array(np.ones((1, 1, 28, 28), dtype="float32"), device=remote.cpu())) + assert len(report.calls) > 0 + + +def test_rpc_graph(): + server = rpc.Server(key="profiling") + remote = rpc.connect("127.0.0.1", server.port, key="profiling") + + mod, params = mlp.get_workload(1) + exe = relay.build(mod, "llvm", params=params) + temp = utils.tempdir() + path = temp.relpath("lib.tar") + exe.export_library(path) + remote.upload(path) + rexec = remote.load_module("lib.tar") + + gr = debug_executor.create(exe.get_graph_json(), rexec, remote.cpu()) + + data = np.random.rand(1, 1, 28, 28).astype("float32") + report = gr.profile(data=data) + assert len(report.calls) > 0 + + +def test_report_serialization(): + mod, params = mlp.get_workload(1) + + exe = relay.vm.compile(mod, "llvm", params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, tvm.cpu()) + + data = np.random.rand(1, 1, 28, 28).astype("float32") + report = vm.profile(data, func_name="main") + + report2 = Report.from_json(report.json()) + # equality on reports compares pointers, so we compare the printed results instead. + assert str(report) == str(report2) + + if __name__ == "__main__": test_papi("llvm", tvm.cpu())