From 68fe5107a5fd88ad0fe782154abb058ff08c1a32 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 25 Mar 2021 17:51:44 +0300 Subject: [PATCH 1/6] [METAL] Fix issue with GPU fails Added first run to auto scheduler. This run is necessary for checking that the generated kernel is correct. When we just run time evaluator with incorrect kernel then it is possible that our application on iOS device will be added to ignore list because of big number of committed incorrect kernels. One run before running auto scheduling helps us to avoid this problem. Added complete handlers to all command buffers in Metal runtime. It helps to handle GPU errors and report about this error to the host application. In case when error happened, we have to create a new stream. Added mechanism for error handling and streams creating from python interface. --- python/tvm/_ffi/runtime_ctypes.py | 18 +++++++- python/tvm/auto_scheduler/measure.py | 9 ++++ src/runtime/metal/metal_common.h | 45 +++++++++++++++---- src/runtime/metal/metal_device_api.mm | 50 ++++++++++++++++----- src/runtime/metal/metal_module.mm | 5 ++- src/runtime/minrpc/minrpc_server.h | 62 +++++++++++++++++++++++++++ src/runtime/minrpc/rpc_reference.h | 9 ++++ src/runtime/rpc/rpc_device_api.cc | 15 +++++++ src/runtime/rpc/rpc_endpoint.cc | 39 +++++++++++++++++ 9 files changed, 229 insertions(+), 23 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 59dc652aeb0b..6192fa984f1b 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -262,9 +262,23 @@ def max_thread_dimensions(self): """ return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8)) - def sync(self): + def create_stream(self): + """Create a new runtime stream at the context.""" + stream = ctypes.c_void_p() + check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream))) + return stream + + def free_stream(self, stream): + """Free a created stream handle.""" + check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream)) + + def set_stream(self, stream): + """Set a created stream handle.""" + check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream)) + + def sync(self, stream=None): """Synchronize until jobs finished at the context.""" - check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) + check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream)) def __eq__(self, other): return ( diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 83f1bcec7ebc..3a37f29f99d9 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1076,6 +1076,8 @@ def _timed_rpc_run( if error_no == 0: try: + stream = dev.create_stream() + dev.set_stream(stream) random_fill = remote.get_function("tvm.contrib.random.random_fill") assert ( random_fill @@ -1108,14 +1110,21 @@ def _timed_rpc_run( "task_inputs not fully matched, check if there's any unexpected error" ) dev.sync() + + # First run for check that the kernel is correct + func.entry_func(*args) + dev.sync() + costs = time_f(*args).results # clean up remote files remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") remote.remove("") + dev.free_stream(stream) # pylint: disable=broad-except except Exception: + dev.free_stream(stream) costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE error_msg = make_traceback_info() diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 55f9022a6b96..f5d12dbcfad5 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -45,6 +45,32 @@ namespace tvm { namespace runtime { namespace metal { +/*! + * \brief Structure for error handling in queues + */ +class Stream { + public: + explicit Stream(id device) : error_happened_(false) { + queue_ = [device newCommandQueue]; + } + ~Stream() { [queue_ release]; } + id GetCommandBuffer() { + id cb = [queue_ commandBuffer]; + [cb addCompletedHandler:^(id buffer) { + if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus(); + }]; + return cb; + } + bool IsErrorHappened() { return error_happened_; } + + private: + void SetErrorStatus() { error_happened_ = true; } + // Queue + id queue_; + // Check if error happened in one previous run + bool error_happened_; +}; + /*! * \brief Process global Metal workspace. */ @@ -52,8 +78,6 @@ class MetalWorkspace final : public DeviceAPI { public: // the devices std::vector > devices; - // the queues - std::vector > queues; // Warp size constant std::vector warp_size; // Whether it is initialized. @@ -62,13 +86,6 @@ class MetalWorkspace final : public DeviceAPI { std::mutex mutex; // Destructor ~MetalWorkspace(); - // Get command queue for given device. - id GetCommandQueue(Device dev) { - ICHECK_EQ(dev.device_type, kDLMetal); - ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) - << "Invalid Metal device_id=" << dev.device_id; - return queues[dev.device_id]; - } // Get device for given device id GetDevice(Device dev) { ICHECK_EQ(dev.device_type, kDLMetal); @@ -84,9 +101,13 @@ class MetalWorkspace final : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(Device dev, void* ptr) final; + TVMStreamHandle CreateStream(Device dev) final; + void FreeStream(Device dev, TVMStreamHandle stream) final; void StreamSync(Device dev, TVMStreamHandle stream) final; + void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + // get the global workspace static MetalWorkspace* Global(); @@ -94,6 +115,10 @@ class MetalWorkspace final : public DeviceAPI { void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) final; + + private: + // Pointers to default allocated streams + std::vector default_streams_; }; /*! \brief Thread local workspace */ @@ -101,6 +126,8 @@ class MetalThreadEntry { public: /*! \brief The current device */ Device device; + /*! \brief The current stream */ + std::vector stream; /*! \brief The shared buffer used for copy. */ std::vector > temp_buffer_; /*! \brief workspace pool */ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index cf8520864e99..03d91a2d365e 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -121,8 +121,8 @@ int GetWarpSize(id dev) { for (auto x : devices) { [x release]; } - for (auto x : queues) { - [x release]; + for (auto x : default_streams_) { + delete x; } } @@ -136,13 +136,17 @@ int GetWarpSize(id dev) { // on iPhone id d = MTLCreateSystemDefaultDevice(); devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); #else NSArray >* devs = MTLCopyAllDevices(); for (size_t i = 0; i < devs.count; ++i) { id d = [devs objectAtIndex:i]; devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; warp_size.push_back(GetWarpSize(d)); } @@ -183,16 +187,25 @@ int GetWarpSize(id dev) { } } +Stream* getStream(TVMStreamHandle stream, int device_id) { + if (stream != nullptr) + return static_cast(stream); + else + return MetalThreadEntry::ThreadLocal()->stream[device_id]; +} + void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { @autoreleasepool { this->Init(); - ICHECK(stream == nullptr); Device dev = dev_from; + Stream* s = getStream(stream, dev.device_id); + if (s->IsErrorHappened()) { + LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream"; + } if (dev_from.device_type == kDLCPU) dev = dev_to; - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); @@ -249,17 +262,34 @@ int GetWarpSize(id dev) { } } +TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { + Stream* stream = new Stream(devices[dev.device_id]); + return static_cast(stream); +} + +void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { + ICHECK(stream != nullptr); + Stream* s = static_cast(stream); + delete s; +} + void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { @autoreleasepool { - ICHECK(stream == nullptr); + Stream* s = getStream(stream, dev.device_id); // commit an empty command buffer and wait until it completes. - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); [cb commit]; [cb waitUntilCompleted]; + if (s->IsErrorHappened()) { + LOG(FATAL) << "Error! Some problems on GPU happaned!"; + } } } +void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); +} + void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index a8b01815bf68..29d726a0ee97 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -185,6 +185,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons @autoreleasepool { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; + auto stream = static_cast(t->stream[device_id]); + if (stream->IsErrorHappened()) return; if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } @@ -192,8 +194,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); - id queue = w_->GetCommandQueue(t->device); - id cb = [queue commandBuffer]; + id cb = stream->GetCommandBuffer(); id encoder = [cb computeCommandEncoder]; [encoder setComputePipelineState:scache_[device_id]]; for (size_t i = 0; i < num_buffer_args_; ++i) { diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 732e1e49d4a4..1dfee70a20e2 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -297,10 +297,22 @@ class MinRPCServer { this->SyscallDevFreeData(values, tcodes, num_args); break; } + case RPCCode::kDevCreateStream: { + this->SyscallDevCreateStream(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeStream: { + this->SyscallDevFreeStream(values, tcodes, num_args); + break; + } case RPCCode::kDevStreamSync: { this->SyscallDevStreamSync(values, tcodes, num_args); break; } + case RPCCode::kDevSetStream: { + this->SyscallDevSetStream(values, tcodes, num_args); + break; + } case RPCCode::kCopyAmongRemote: { this->SyscallCopyAmongRemote(values, tcodes, num_args); break; @@ -444,6 +456,39 @@ class MinRPCServer { } } + void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kDLDevice); + + DLDevice dev = values[0].v_device; + void* handle; + + int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { MINRPC_CHECK(num_args == 2); MINRPC_CHECK(tcodes[0] == kDLDevice); @@ -461,6 +506,23 @@ class MinRPCServer { } } + void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { io_->Exit(static_cast(code)); } diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index e42508a73959..494b87e54b60 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -49,7 +49,10 @@ enum class RPCCode : int { kDevGetAttr, kDevAllocData, kDevFreeData, + kDevCreateStream, + kDevFreeStream, kDevStreamSync, + kDevSetStream, kCopyAmongRemote, kDevAllocDataWithScope, }; @@ -104,8 +107,14 @@ inline const char* RPCCodeToString(RPCCode code) { return "kDevAllocData"; case RPCCode::kDevFreeData: return "kDevFreeData"; + case RPCCode::kDevCreateStream: + return "kDevCreateStream"; + case RPCCode::kDevFreeStream: + return "kDevFreeStream"; case RPCCode::kDevStreamSync: return "kDevStreamSync"; + case RPCCode::kDevSetStream: + return "kDevSetStream"; case RPCCode::kCopyAmongRemote: return "kCopyAmongRemote"; case RPCCode::kDevAllocDataWithScope: diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 1d6fb85d9495..a2d1ac17ef7f 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -111,11 +111,26 @@ class RPCDeviceAPI final : public DeviceAPI { } } + TVMStreamHandle CreateStream(Device dev) { + auto remote_dev = RemoveRPCSessionMask(dev); + return GetSess(dev)->GetDeviceAPI(remote_dev)->CreateStream(remote_dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->FreeStream(remote_dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { auto remote_dev = RemoveRPCSessionMask(dev); GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream); + } + protected: void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index b5768146b3f7..ba33b5065ebb 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -920,6 +920,24 @@ void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); } +void RPCDevCreateStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + void* data = handler->GetDeviceAPI(dev)->CreateStream(dev); + *rv = data; +} + +void RPCDevFreeStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1]; + handler->GetDeviceAPI(dev)->FreeStream(dev, stream); +} + +void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1]; + handler->GetDeviceAPI(dev)->SetStream(dev, stream); +} + void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { // Event handler sit at clean state at this point. switch (code) { @@ -945,9 +963,18 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; + case RPCCode::kDevCreateStream: + SysCallHandler(RPCDevCreateStream); + break; + case RPCCode::kDevFreeStream: + SysCallHandler(RPCDevFreeStream); + break; case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break; + case RPCCode::kDevSetStream: + SysCallHandler(RPCDevSetStream); + break; case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; @@ -1033,10 +1060,22 @@ class RPCClientSession : public RPCSession, public DeviceAPI { endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream); } + TVMStreamHandle CreateStream(Device dev) final { + return endpoint_->SysCallRemote(RPCCode::kDevCreateStream, dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeStream, dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { endpoint_->SysCallRemote(RPCCode::kDevStreamSync, dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream); + } + DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; } bool IsLocalSession() const final { return false; } From 932c2cdc43fe2c7e28ead374fa843de3fb41e90f Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Tue, 13 Apr 2021 07:35:49 +0300 Subject: [PATCH 2/6] Try to fix QEMU build --- src/runtime/crt/common/crt_runtime_api.c | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index c8044b49a8d0..172972cc57b9 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -129,6 +129,15 @@ int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream return 0; } +int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) { + out = NULL; + return 0; +} + +int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { return 0; } + +int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { return 0; } + int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { return 0; } static TVMMutableFuncRegistry global_func_registry; From 4a9d9728bd608f21e66c758199e1774bee102842 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 14 Apr 2021 08:40:49 +0300 Subject: [PATCH 3/6] Apply comment --- src/runtime/minrpc/rpc_reference.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index 494b87e54b60..ace3e2bbb1b8 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -49,12 +49,12 @@ enum class RPCCode : int { kDevGetAttr, kDevAllocData, kDevFreeData, - kDevCreateStream, - kDevFreeStream, kDevStreamSync, - kDevSetStream, kCopyAmongRemote, kDevAllocDataWithScope, + kDevCreateStream, + kDevFreeStream, + kDevSetStream, }; /*! From dfea55567d88bf30a50c7b9d9b70ea9c4c2c9434 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 15 Apr 2021 11:20:00 +0300 Subject: [PATCH 4/6] Apply comments and fix build --- python/tvm/_ffi/runtime_ctypes.py | 40 +++++++++++++++++++++++----- python/tvm/auto_scheduler/measure.py | 14 +++++++--- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 6192fa984f1b..49a86fc92d46 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -262,22 +262,48 @@ def max_thread_dimensions(self): """ return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8)) - def create_stream(self): - """Create a new runtime stream at the context.""" + def create_raw_stream(self): + """Create a new runtime stream at the context. + + User should free the stream after use. + + Returns + ------- + stream : TVMStreamHandle + The created runtime stream. + """ stream = ctypes.c_void_p() check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream))) return stream - def free_stream(self, stream): - """Free a created stream handle.""" + def free_raw_stream(self, stream): + """Free a created stream handle. + + Parameters + ---------- + stream : TVMStreamHandle + The stream which should to be released. + """ check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream)) - def set_stream(self, stream): - """Set a created stream handle.""" + def set_raw_stream(self, stream): + """Set a created stream handle. + + Parameters + ---------- + stream : TVMStreamHandle + The stream which should to be set to the device. + """ check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream)) def sync(self, stream=None): - """Synchronize until jobs finished at the context.""" + """Synchronize until jobs finished at the context. + + Parameters + ---------- + stream : TVMStreamHandle + Jobs in this stream should be finished. + """ check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream)) def __eq__(self, other): diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 3a37f29f99d9..427e7b259b8e 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1076,8 +1076,12 @@ def _timed_rpc_run( if error_no == 0: try: - stream = dev.create_stream() - dev.set_stream(stream) + stream = None + try: + stream = dev.create_raw_stream() + dev.set_raw_stream(stream) + except Exception: + pass random_fill = remote.get_function("tvm.contrib.random.random_fill") assert ( random_fill @@ -1121,10 +1125,12 @@ def _timed_rpc_run( remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") remote.remove("") - dev.free_stream(stream) + if stream is not None: + dev.free_raw_stream(stream) # pylint: disable=broad-except except Exception: - dev.free_stream(stream) + if stream is not None: + dev.free_raw_stream(stream) costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE error_msg = make_traceback_info() From ce4fafdf5646833df4eef0f01e998705f438abc9 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 15 Apr 2021 13:58:05 +0300 Subject: [PATCH 5/6] Apply comments and fix lint --- python/tvm/auto_scheduler/measure.py | 2 +- src/runtime/metal/metal_common.h | 2 +- src/runtime/metal/metal_device_api.mm | 10 +++++----- src/runtime/metal/metal_module.mm | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 427e7b259b8e..7aad59f76361 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1080,7 +1080,7 @@ def _timed_rpc_run( try: stream = dev.create_raw_stream() dev.set_raw_stream(stream) - except Exception: + finally: pass random_fill = remote.get_function("tvm.contrib.random.random_fill") assert ( diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index f5d12dbcfad5..9ebe04efbe4c 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -61,7 +61,7 @@ class Stream { }]; return cb; } - bool IsErrorHappened() { return error_happened_; } + bool HasErrorHappened() { return error_happened_; } private: void SetErrorStatus() { error_happened_ = true; } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 03d91a2d365e..85b427509133 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -187,7 +187,7 @@ int GetWarpSize(id dev) { } } -Stream* getStream(TVMStreamHandle stream, int device_id) { +Stream* GetStream(TVMStreamHandle stream, int device_id) { if (stream != nullptr) return static_cast(stream); else @@ -200,8 +200,8 @@ int GetWarpSize(id dev) { @autoreleasepool { this->Init(); Device dev = dev_from; - Stream* s = getStream(stream, dev.device_id); - if (s->IsErrorHappened()) { + Stream* s = GetStream(stream, dev.device_id); + if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream"; } if (dev_from.device_type == kDLCPU) dev = dev_to; @@ -275,12 +275,12 @@ int GetWarpSize(id dev) { void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { @autoreleasepool { - Stream* s = getStream(stream, dev.device_id); + Stream* s = GetStream(stream, dev.device_id); // commit an empty command buffer and wait until it completes. id cb = s->GetCommandBuffer(); [cb commit]; [cb waitUntilCompleted]; - if (s->IsErrorHappened()) { + if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned!"; } } diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 29d726a0ee97..e22caa21a81e 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -186,7 +186,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; auto stream = static_cast(t->stream[device_id]); - if (stream->IsErrorHappened()) return; + if (stream->HasErrorHappened()) return; if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } From 1fdec5321d4d8150f845dc2bcc8b73ffc7012ea1 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Fri, 16 Apr 2021 15:23:34 +0300 Subject: [PATCH 6/6] Fix CI --- python/tvm/auto_scheduler/measure.py | 14 ++++---------- src/runtime/c_runtime_api.cc | 10 ++-------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 7aad59f76361..84dff157aa50 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1076,12 +1076,8 @@ def _timed_rpc_run( if error_no == 0: try: - stream = None - try: - stream = dev.create_raw_stream() - dev.set_raw_stream(stream) - finally: - pass + stream = dev.create_raw_stream() + dev.set_raw_stream(stream) random_fill = remote.get_function("tvm.contrib.random.random_fill") assert ( random_fill @@ -1125,12 +1121,10 @@ def _timed_rpc_run( remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") remote.remove("") - if stream is not None: - dev.free_raw_stream(stream) + dev.free_raw_stream(stream) # pylint: disable=broad-except except Exception: - if stream is not None: - dev.free_raw_stream(stream) + dev.free_raw_stream(stream) costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE error_msg = make_traceback_info() diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index b9e8c2549fd5..d042cb406089 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -190,17 +190,11 @@ void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, s void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } -TVMStreamHandle DeviceAPI::CreateStream(Device dev) { - LOG(FATAL) << "Device does not support stream api."; - return nullptr; -} +TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } -void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) { - LOG(FATAL) << "Device does not support stream api."; -} +void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { - LOG(FATAL) << "Device does not support stream api."; } //--------------------------------------------------------