diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 59dc652aeb0b..49a86fc92d46 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -262,9 +262,49 @@ def max_thread_dimensions(self): """ return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8)) - def sync(self): - """Synchronize until jobs finished at the context.""" - check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None)) + def create_raw_stream(self): + """Create a new runtime stream at the context. + + The caller should release the stream with free_raw_stream after use. + + Returns + ------- + stream : TVMStreamHandle + The created runtime stream. + """ + stream = ctypes.c_void_p() + check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream))) + return stream + + def free_raw_stream(self, stream): + """Free a created stream handle. + + Parameters + ---------- + stream : TVMStreamHandle + The stream to be released. + """ + check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream)) + + def set_raw_stream(self, stream): + """Set a created stream handle. + + Parameters + ---------- + stream : TVMStreamHandle + The stream to be set to the device. + """ + check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream)) + + def sync(self, stream=None): + """Synchronize until jobs finished at the context. + + Parameters + ---------- + stream : TVMStreamHandle, optional + Jobs in this stream should be finished. None (the default) is passed to TVMSynchronize unchanged. 
+ """ + check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream)) def __eq__(self, other): return ( diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 83f1bcec7ebc..84dff157aa50 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1076,6 +1076,10 @@ def _timed_rpc_run( if error_no == 0: + # Pre-bind so the except path can tell whether a stream was ever created + stream = None try: + stream = dev.create_raw_stream() + dev.set_raw_stream(stream) random_fill = remote.get_function("tvm.contrib.random.random_fill") assert ( random_fill @@ -1108,14 +1112,22 @@ def _timed_rpc_run( "task_inputs not fully matched, check if there's any unexpected error" ) dev.sync() + + # First run for check that the kernel is correct + func.entry_func(*args) + dev.sync() + costs = time_f(*args).results # clean up remote files remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") remote.remove("") + dev.free_raw_stream(stream) # pylint: disable=broad-except except Exception: + if stream is not None: + dev.free_raw_stream(stream) costs = (MAX_FLOAT,) error_no = MeasureErrorNo.RUNTIME_DEVICE error_msg = make_traceback_info() diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index b9e8c2549fd5..d042cb406089 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -190,17 +190,11 @@ void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, s void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } -TVMStreamHandle DeviceAPI::CreateStream(Device dev) { - LOG(FATAL) << "Device does not support stream api."; - return nullptr; -} +TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } -void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) { - LOG(FATAL) << "Device does not support stream api."; -} +void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle 
event_dst) { - LOG(FATAL) << "Device does not support stream api."; } //-------------------------------------------------------- diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index c8044b49a8d0..172972cc57b9 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -129,6 +129,15 @@ int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream return 0; } +int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) { + *out = NULL;  // write through the out-parameter so the caller receives a defined (null) handle + return 0; +} + +int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { return 0; } + +int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { return 0; } + int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { return 0; } static TVMMutableFuncRegistry global_func_registry; diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 55f9022a6b96..9ebe04efbe4c 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -45,6 +45,32 @@ namespace tvm { namespace runtime { namespace metal { +/*! + * \brief Owns one command queue and records whether a command buffer created from it completed with an error + */ +class Stream { + public: + explicit Stream(id device) : error_happened_(false) { + queue_ = [device newCommandQueue]; + } + ~Stream() { [queue_ release]; } + id GetCommandBuffer() { + id cb = [queue_ commandBuffer]; + [cb addCompletedHandler:^(id buffer) { + if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus(); + }]; + return cb; + } + bool HasErrorHappened() { return error_happened_; } + + private: + void SetErrorStatus() { error_happened_ = true; } + // Queue + id queue_; + // True once a command buffer completed with an error; NOTE(review): set from the completion handler without synchronization - confirm this is safe + bool error_happened_; +}; + /*! * \brief Process global Metal workspace. 
 */ @@ -52,8 +78,6 @@ class MetalWorkspace final : public DeviceAPI { public: // the devices std::vector > devices; - // the queues - std::vector > queues; // Warp size constant std::vector warp_size; // Whether it is initialized. @@ -62,13 +86,6 @@ class MetalWorkspace final : public DeviceAPI { std::mutex mutex; // Destructor ~MetalWorkspace(); - // Get command queue for given device. - id GetCommandQueue(Device dev) { - ICHECK_EQ(dev.device_type, kDLMetal); - ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) - << "Invalid Metal device_id=" << dev.device_id; - return queues[dev.device_id]; - } // Get device for given device id GetDevice(Device dev) { ICHECK_EQ(dev.device_type, kDLMetal); @@ -84,9 +101,13 @@ class MetalWorkspace final : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(Device dev, void* ptr) final; + TVMStreamHandle CreateStream(Device dev) final; + void FreeStream(Device dev, TVMStreamHandle stream) final; void StreamSync(Device dev, TVMStreamHandle stream) final; + void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + // get the global workspace static MetalWorkspace* Global(); @@ -94,6 +115,10 @@ class MetalWorkspace final : public DeviceAPI { void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) final; + + private: + // Default stream of each device, appended in the same order as devices; the workspace deletes them + std::vector default_streams_; }; /*! \brief Thread local workspace */ @@ -101,6 +126,8 @@ class MetalThreadEntry { public: /*! \brief The current device */ Device device; + /*! \brief The current stream of each device, indexed by device_id */ + std::vector stream; /*! \brief The shared buffer used for copy. 
*/ std::vector > temp_buffer_; /*! \brief workspace pool */ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index cf8520864e99..85b427509133 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -121,8 +121,8 @@ int GetWarpSize(id dev) { for (auto x : devices) { [x release]; } - for (auto x : queues) { - [x release]; + for (auto x : default_streams_) { + delete x; } } @@ -136,13 +136,17 @@ int GetWarpSize(id dev) { // on iPhone id d = MTLCreateSystemDefaultDevice(); devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); #else NSArray >* devs = MTLCopyAllDevices(); for (size_t i = 0; i < devs.count; ++i) { id d = [devs objectAtIndex:i]; devices.push_back(d); - queues.push_back([d newCommandQueue]); + Stream* stream = new Stream(d); + MetalThreadEntry::ThreadLocal()->stream.push_back(stream); + default_streams_.push_back(stream); LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; warp_size.push_back(GetWarpSize(d)); } @@ -183,16 +187,25 @@ int GetWarpSize(id dev) { } } +Stream* GetStream(TVMStreamHandle stream, int device_id) { + if (stream != nullptr) + return static_cast(stream); + else + return MetalThreadEntry::ThreadLocal()->stream[device_id]; +} + void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { @autoreleasepool { this->Init(); - ICHECK(stream == nullptr); Device dev = dev_from; + Stream* s = GetStream(stream, dev.device_id); + if (s->HasErrorHappened()) { + LOG(FATAL) << "Error! Some problems on GPU happaned! 
Cannot copy data to current stream"; + } if (dev_from.device_type == kDLCPU) dev = dev_to; - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); @@ -249,17 +262,34 @@ int GetWarpSize(id dev) { } } +TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { + Stream* stream = new Stream(devices[dev.device_id]); + return static_cast(stream); +} + +void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { + ICHECK(stream != nullptr); + Stream* s = static_cast(stream); + delete s; +} + void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { @autoreleasepool { - ICHECK(stream == nullptr); + Stream* s = GetStream(stream, dev.device_id); // commit an empty command buffer and wait until it completes. - id queue = GetCommandQueue(dev); - id cb = [queue commandBuffer]; + id cb = s->GetCommandBuffer(); [cb commit]; [cb waitUntilCompleted]; + if (s->HasErrorHappened()) { + LOG(FATAL) << "Error! 
Some problems on GPU happaned!"; + } } } +void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {  // replaces the calling thread's current stream for dev.device_id + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); +} + void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index a8b01815bf68..e22caa21a81e 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -185,6 +185,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons @autoreleasepool { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->device.device_id; + auto stream = static_cast(t->stream[device_id]);  // this thread's current stream for its current device; the launch is skipped once it has recorded an error + if (stream->HasErrorHappened()) return; if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } @@ -192,8 +194,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); - id queue = w_->GetCommandQueue(t->device); - id cb = [queue commandBuffer]; + id cb = stream->GetCommandBuffer(); id encoder = [cb computeCommandEncoder]; [encoder setComputePipelineState:scache_[device_id]]; for (size_t i = 0; i < num_buffer_args_; ++i) { diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 732e1e49d4a4..1dfee70a20e2 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -297,10 +297,22 @@ class MinRPCServer { this->SyscallDevFreeData(values, tcodes, num_args); break; } + case RPCCode::kDevCreateStream: { + this->SyscallDevCreateStream(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeStream: { 
+ this->SyscallDevFreeStream(values, tcodes, num_args); + break; + } case RPCCode::kDevStreamSync: { this->SyscallDevStreamSync(values, tcodes, num_args); break; } + case RPCCode::kDevSetStream: { + this->SyscallDevSetStream(values, tcodes, num_args); + break; + } case RPCCode::kCopyAmongRemote: { this->SyscallCopyAmongRemote(values, tcodes, num_args); break; @@ -444,6 +456,39 @@ class MinRPCServer { } } + void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kDLDevice); + + DLDevice dev = values[0].v_device; + void* handle = nullptr;  // defined value even if the callee reports success without writing the handle + + int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { MINRPC_CHECK(num_args == 2); MINRPC_CHECK(tcodes[0] == kDLDevice); @@ -461,6 +506,23 @@ class MinRPCServer { } } + void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kDLDevice); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + DLDevice dev = values[0].v_device; + void* handle = values[1].v_handle; + + int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { io_->Exit(static_cast(code)); } diff --git 
a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index e42508a73959..ace3e2bbb1b8 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -52,6 +52,9 @@ enum class RPCCode : int { kDevStreamSync, kCopyAmongRemote, kDevAllocDataWithScope, + kDevCreateStream, + kDevFreeStream, + kDevSetStream, }; /*! @@ -104,8 +107,14 @@ inline const char* RPCCodeToString(RPCCode code) { return "kDevAllocData"; case RPCCode::kDevFreeData: return "kDevFreeData"; + case RPCCode::kDevCreateStream: + return "kDevCreateStream"; + case RPCCode::kDevFreeStream: + return "kDevFreeStream"; case RPCCode::kDevStreamSync: return "kDevStreamSync"; + case RPCCode::kDevSetStream: + return "kDevSetStream"; case RPCCode::kCopyAmongRemote: return "kCopyAmongRemote"; case RPCCode::kDevAllocDataWithScope: diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 1d6fb85d9495..a2d1ac17ef7f 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -111,11 +111,26 @@ class RPCDeviceAPI final : public DeviceAPI { } } + TVMStreamHandle CreateStream(Device dev) final {  // forwarded to the DeviceAPI of the session that owns dev + auto remote_dev = RemoveRPCSessionMask(dev); + return GetSess(dev)->GetDeviceAPI(remote_dev)->CreateStream(remote_dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) final { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->FreeStream(remote_dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { auto remote_dev = RemoveRPCSessionMask(dev); GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) final { + auto remote_dev = RemoveRPCSessionMask(dev); + GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream); + } + protected: void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, 
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index b5768146b3f7..ba33b5065ebb 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -920,6 +920,24 @@ void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); } +void RPCDevCreateStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + void* data = handler->GetDeviceAPI(dev)->CreateStream(dev);  // stream handle from dev's DeviceAPI + *rv = data;  // handed back through rv +} + +void RPCDevFreeStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1];  // args are (dev, stream); rv is left untouched + handler->GetDeviceAPI(dev)->FreeStream(dev, stream); +} + +void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + Device dev = args[0]; + TVMStreamHandle stream = args[1];  // args are (dev, stream); rv is left untouched + handler->GetDeviceAPI(dev)->SetStream(dev, stream); +} + void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { // Event handler sit at clean state at this point. 
switch (code) { @@ -945,9 +963,18 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; + case RPCCode::kDevCreateStream: + SysCallHandler(RPCDevCreateStream); + break; + case RPCCode::kDevFreeStream: + SysCallHandler(RPCDevFreeStream); + break; case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break; + case RPCCode::kDevSetStream: + SysCallHandler(RPCDevSetStream); + break; case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; @@ -1033,10 +1060,22 @@ class RPCClientSession : public RPCSession, public DeviceAPI { endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream); } + TVMStreamHandle CreateStream(Device dev) final {  // issued through endpoint_->SysCallRemote with kDevCreateStream + return endpoint_->SysCallRemote(RPCCode::kDevCreateStream, dev); + } + + void FreeStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeStream, dev, stream); + } + void StreamSync(Device dev, TVMStreamHandle stream) final { endpoint_->SysCallRemote(RPCCode::kDevStreamSync, dev, stream); } + void SetStream(Device dev, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream); + } + DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; } bool IsLocalSession() const final { return false; }