diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7575d6c2b4d6..7fba5355f077 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -387,6 +387,12 @@ if(BUILD_FOR_HEXAGON)
   add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0)
 endif()
 
+# The distributed disco runtime is disabled for Hexagon
+if (NOT BUILD_FOR_HEXAGON)
+  tvm_file_glob(GLOB RUNTIME_DISCO_DISTRIBUTED_SRCS src/runtime/disco/distributed/*.cc)
+  list(APPEND RUNTIME_SRCS ${RUNTIME_DISCO_DISTRIBUTED_SRCS})
+endif()
+
 # Package runtime rules
 if(NOT USE_RTTI)
   add_definitions(-DDMLC_ENABLE_RTTI=0)
diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h
index 301b5b8d626b..c9875cedfbcb 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -52,6 +52,7 @@ class DiscoWorker {
   explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
                        WorkerZeroData* worker_zero_data, DiscoChannel* channel)
       : worker_id(worker_id),
+        local_worker_id(worker_id),
         num_workers(num_workers),
         num_groups(num_groups),
         default_device(Device{DLDeviceType::kDLCPU, 0}),
@@ -68,6 +69,9 @@ class DiscoWorker {
 
   /*! \brief The id of the worker.*/
   int worker_id;
+  /*! \brief The local id of the worker. This can be different from worker_id if the session is
+   * composed of multiple sub-sessions. */
+  int local_worker_id;
   /*! \brief Total number of workers */
   int num_workers;
   /*! \brief Total number of workers */
diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h
index 97fa79096d63..9c34f8a2af9e 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -281,6 +281,7 @@ class Session : public ObjectRef {
    */
   TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
                                         String process_pool_creator, String entrypoint);
+
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
 };
 
diff --git a/python/tvm/exec/disco_remote_socket_session.py b/python/tvm/exec/disco_remote_socket_session.py
new file mode 100644
index 000000000000..3111ce30ac4b
--- /dev/null
+++ b/python/tvm/exec/disco_remote_socket_session.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Launch a disco session on the remote node and connect to the server."""
+import sys
+import tvm
+from . import disco_worker as _  # pylint: disable=unused-import
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 4:
+        print("Usage: <server_host> <server_port> <num_workers>")
+        sys.exit(1)
+
+    server_host = sys.argv[1]
+    server_port = int(sys.argv[2])
+    num_workers = int(sys.argv[3])
+    func = tvm.get_global_func("runtime.disco.RemoteSocketSession")
+    func(server_host, server_port, num_workers)
diff --git a/python/tvm/runtime/disco/__init__.py b/python/tvm/runtime/disco/__init__.py
index 856e69bc3598..2ba524cade66 100644
--- a/python/tvm/runtime/disco/__init__.py
+++ b/python/tvm/runtime/disco/__init__.py
@@ -22,4 +22,5 @@
     ProcessSession,
     Session,
     ThreadedSession,
+    SocketSession,
 )
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index 38c4f2a2354c..e33480ca3584 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -574,6 +574,29 @@ def _configure_structlog(self) -> None:
         func(config, os.getpid())
 
 
+@register_func("runtime.disco.create_socket_session_local_workers")
+def _create_socket_session_local_workers(num_workers) -> Session:
+    """Create the local session for each distributed node over socket session."""
+    return ProcessSession(num_workers)
+
+
+@register_object("runtime.disco.SocketSession")
+class SocketSession(Session):
+    """A Disco session backed by socket-based multi-node communication."""
+
+    def __init__(
+        self, num_nodes: int, num_workers_per_node: int, num_groups: int, host: str, port: int
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.SocketSession,  # type: ignore # pylint: disable=no-member
+            num_nodes,
+            num_workers_per_node,
+            num_groups,
+            host,
+            port,
+        )
+
+
 @register_func("runtime.disco._configure_structlog")
 def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None:
     """Configure structlog for all disco workers
diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h
index 1a4df634b738..0e4ca614d418 100644
--- a/src/runtime/disco/bcast_session.h
+++ b/src/runtime/disco/bcast_session.h
@@ -65,6 +65,16 @@ class BcastSessionObj : public SessionObj {
    * \param TVMArgs The input arguments in TVM's PackedFunc calling convention
    */
   virtual void BroadcastPacked(const TVMArgs& args) = 0;
+
+  /*!
+   * \brief Send a packed sequence to a worker. This function is usually called by the controler to
+   * communicate with worker-0, because the worker-0 is assumed to be always collocated with the
+   * controler. Sending to other workers may not be supported.
+   * \param worker_id The worker id to send the packed sequence to.
+   * \param args The packed sequence to send.
+   */
+  virtual void SendPacked(int worker_id, const TVMArgs& args) = 0;
+
   /*!
    * \brief Receive a packed sequence from a worker. This function is usually called by the
    * controler to communicate with worker-0, because the worker-0 is assumed to be always
@@ -83,6 +93,16 @@ class BcastSessionObj : public SessionObj {
 
   struct Internal;
   friend struct Internal;
+  friend class SocketSessionObj;
+  friend class RemoteSocketSession;
+};
+
+/*!
+ * \brief Managed reference to BcastSessionObj.
+ */
+class BcastSession : public Session {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session, BcastSessionObj);
 };
 
 }  // namespace runtime
diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc
index b281a3aca7da..4e6350d3bb12 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -129,7 +129,7 @@ struct DiscoWorker::Impl {
   }
 
   static void CopyFromWorker0(DiscoWorker* self, int reg_id) {
-    if (self->worker_zero_data != nullptr) {
+    if (self->worker_id == 0) {
       NDArray tgt = GetNDArrayFromHost(self);
       NDArray src = GetReg(self, reg_id);
       tgt.CopyFrom(src);
@@ -137,7 +137,7 @@ struct DiscoWorker::Impl {
   }
 
   static void CopyToWorker0(DiscoWorker* self, int reg_id) {
-    if (self->worker_zero_data != nullptr) {
+    if (self->worker_id == 0) {
       NDArray src = GetNDArrayFromHost(self);
       NDArray tgt = GetReg(self, reg_id);
       tgt.CopyFrom(src);
diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc
new file mode 100644
index 000000000000..07196be3056b
--- /dev/null
+++ b/src/runtime/disco/distributed/socket_session.cc
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include
+
+#include
+
+#include "../../../support/socket.h"
+#include "../bcast_session.h"
+#include "../message_queue.h"
+
+namespace tvm {
+namespace runtime {
+
+using namespace tvm::support;
+
+enum class DiscoSocketAction {
+  kShutdown = static_cast<int>(DiscoAction::kShutDown),
+  kSend,
+  kReceive,
+};
+
+class DiscoSocketChannel : public DiscoChannel {
+ public:
+  explicit DiscoSocketChannel(const TCPSocket& socket)
+      : socket_(socket), message_queue_(&socket_) {}
+
+  DiscoSocketChannel(DiscoSocketChannel&& other) = delete;
+  DiscoSocketChannel(const DiscoSocketChannel& other) = delete;
+  void Send(const TVMArgs& args) { message_queue_.Send(args); }
+  TVMArgs Recv() { return message_queue_.Recv(); }
+  void Reply(const TVMArgs& args) { message_queue_.Send(args); }
+  TVMArgs RecvReply() { return message_queue_.Recv(); }
+
+ private:
+  TCPSocket socket_;
+  DiscoStreamMessageQueue message_queue_;
+};
+
+class SocketSessionObj : public BcastSessionObj {
+ public:
+  explicit SocketSessionObj(int num_nodes, int num_workers_per_node, int num_groups,
+                            const String& host, int port)
+      : num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) {
+    const PackedFunc* f_create_local_session =
+        Registry::Get("runtime.disco.create_socket_session_local_workers");
+    ICHECK(f_create_local_session != nullptr)
+        << "Cannot find function runtime.disco.create_socket_session_local_workers";
+    local_session_ =
+        ((*f_create_local_session)(num_workers_per_node)).AsObjectRef<BcastSession>();
+    DRef f_init_workers =
+        local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
+    local_session_->CallPacked(f_init_workers, num_nodes_, /*node_id=*/0, num_groups,
+                               num_workers_per_node_);
+
+    Socket::Startup();
+    socket_.Create();
+    socket_.SetKeepAlive(true);
+    socket_.Bind(SockAddr(host.c_str(), port));
+    socket_.Listen();
+    LOG(INFO) << "SocketSession controller listening on " << host << ":" << port;
+
+    TVMValue values[4];
+    int type_codes[4];
+    TVMArgsSetter setter(values, type_codes);
+    setter(0, num_nodes);
+    setter(1, num_workers_per_node);
+    setter(2, num_groups);
+
+    for (int i = 0; i + 1 < num_nodes; ++i) {
+      SockAddr addr;
+      remote_sockets_.push_back(socket_.Accept(&addr));
+      remote_channels_.emplace_back(std::make_unique<DiscoSocketChannel>(remote_sockets_.back()));
+      setter(3, i + 1);
+      // Send metadata to each remote node:
+      // - num_nodes
+      // - num_workers_per_node
+      // - num_groups
+      // - node_id
+      remote_channels_.back()->Send(TVMArgs(values, type_codes, 4));
+      LOG(INFO) << "Remote node " << addr.AsString() << " connected";
+    }
+  }
+
+  int64_t GetNumWorkers() final { return num_nodes_ * num_workers_per_node_; }
+
+  TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      return local_session_->DebugGetFromRemote(reg_id, worker_id);
+    } else {
+      std::vector<TVMValue> values(5);
+      std::vector<int> type_codes(5);
+      PackArgs(values.data(), type_codes.data(), static_cast<int>(DiscoSocketAction::kSend),
+               worker_id, static_cast<int>(DiscoAction::kDebugGetFromRemote), reg_id, worker_id);
+
+      remote_channels_[node_id - 1]->Send(TVMArgs(values.data(), type_codes.data(), values.size()));
+      TVMArgs args = this->RecvReplyPacked(worker_id);
+      ICHECK_EQ(args.size(), 2);
+      ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugGetFromRemote);
+      TVMRetValue result;
+      result = args[1];
+      return result;
+    }
+  }
+
+  void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      local_session_->DebugSetRegister(reg_id, value, worker_id);
+    } else {
+      ObjectRef wrapped{nullptr};
+      if (value.type_code() == kTVMNDArrayHandle || value.type_code() == kTVMObjectHandle) {
+        wrapped = DiscoDebugObject::Wrap(value);
+        TVMValue tvm_value;
+        int type_code = kTVMObjectHandle;
+        tvm_value.v_handle = const_cast<Object*>(wrapped.get());
+        value = TVMArgValue(tvm_value, type_code);
+      }
+      {
+        TVMValue values[6];
+        int type_codes[6];
+        PackArgs(values, type_codes, static_cast<int>(DiscoSocketAction::kSend), worker_id,
+                 static_cast<int>(DiscoAction::kDebugSetRegister), reg_id, worker_id, value);
+        remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 6));
+      }
+      TVMRetValue result;
+      TVMArgs args = this->RecvReplyPacked(worker_id);
+      ICHECK_EQ(args.size(), 1);
+      ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugSetRegister);
+    }
+  }
+
+  void BroadcastPacked(const TVMArgs& args) final {
+    local_session_->BroadcastPacked(args);
+    std::vector<TVMValue> values(args.size() + 2);
+    std::vector<int> type_codes(args.size() + 2);
+    PackArgs(values.data(), type_codes.data(), static_cast<int>(DiscoSocketAction::kSend), -1);
+    std::copy(args.values, args.values + args.size(), values.begin() + 2);
+    std::copy(args.type_codes, args.type_codes + args.size(), type_codes.begin() + 2);
+    for (auto& channel : remote_channels_) {
+      channel->Send(TVMArgs(values.data(), type_codes.data(), values.size()));
+    }
+  }
+
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      local_session_->SendPacked(worker_id, args);
+      return;
+    }
+    std::vector<TVMValue> values(args.size() + 2);
+    std::vector<int> type_codes(args.size() + 2);
+    PackArgs(values.data(), type_codes.data(), static_cast<int>(DiscoSocketAction::kSend),
+             worker_id);
+    std::copy(args.values, args.values + args.size(), values.begin() + 2);
+    std::copy(args.type_codes, args.type_codes + args.size(), type_codes.begin() + 2);
+    remote_channels_[node_id - 1]->Send(TVMArgs(values.data(), type_codes.data(), values.size()));
+  }
+
+  TVMArgs RecvReplyPacked(int worker_id) final {
+    int node_id = worker_id / num_workers_per_node_;
+    if (node_id == 0) {
+      return local_session_->RecvReplyPacked(worker_id);
+    }
+    TVMValue values[2];
+    int type_codes[2];
+    PackArgs(values, type_codes, static_cast<int>(DiscoSocketAction::kReceive), worker_id);
+    remote_channels_[node_id - 1]->Send(TVMArgs(values, type_codes, 2));
+    return remote_channels_[node_id - 1]->Recv();
+  }
+
+  void AppendHostNDArray(const NDArray& host_array) final {
+    local_session_->AppendHostNDArray(host_array);
+  }
+
+  void Shutdown() final {
+    // The local session is shut down implicitly by its destructor.
+    TVMValue values[2];
+    int type_codes[2];
+    PackArgs(values, type_codes, static_cast<int>(DiscoSocketAction::kShutdown), -1);
+    for (auto& channel : remote_channels_) {
+      channel->Send(TVMArgs(values, type_codes, 2));
+    }
+    for (auto& socket : remote_sockets_) {
+      socket.Close();
+    }
+    remote_sockets_.clear();
+    remote_channels_.clear();
+    if (!socket_.IsClosed()) {
+      socket_.Close();
+    }
+    Socket::Finalize();
+  }
+
+  ~SocketSessionObj() { Shutdown(); }
+
+  static constexpr const char* _type_key = "runtime.disco.SocketSession";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SocketSessionObj, BcastSessionObj);
+  int num_nodes_;
+  int num_workers_per_node_;
+  TCPSocket socket_;
+  std::vector<TCPSocket> remote_sockets_;
+  std::vector<std::unique_ptr<DiscoSocketChannel>> remote_channels_;
+  BcastSession local_session_{nullptr};
+};
+
+TVM_REGISTER_OBJECT_TYPE(SocketSessionObj);
+
+class RemoteSocketSession {
+ public:
+  explicit RemoteSocketSession(const String& server_host, int server_port, int num_local_workers) {
+    socket_.Create();
+    socket_.SetKeepAlive(true);
+    SockAddr server_addr{server_host.c_str(), server_port};
+    Socket::Startup();
+    if (!socket_.Connect(server_addr)) {
+      LOG(FATAL) << "Failed to connect to server " << server_addr.AsString()
+                 << ", errno = " << Socket::GetLastErrorCode();
+    }
+    channel_ = std::make_unique<DiscoSocketChannel>(socket_);
+    TVMArgs metadata = channel_->Recv();
+    ICHECK_EQ(metadata.size(), 4);
+    num_nodes_ = metadata[0].operator int();
+    num_workers_per_node_ = metadata[1].operator int();
+    num_groups_ = metadata[2].operator int();
+    node_id_ = metadata[3].operator int();
+    CHECK_GE(num_local_workers, num_workers_per_node_);
+    InitLocalSession();
+  }
+
+  void MainLoop() {
+    while (true) {
+      TVMArgs args = channel_->Recv();
+      DiscoSocketAction action = static_cast<DiscoSocketAction>(args[0].operator int());
+      int worker_id = args[1].operator int();
+      int local_worker_id = worker_id - node_id_ * num_workers_per_node_;
+      switch (action) {
+        case DiscoSocketAction::kSend: {
+          args = TVMArgs(args.values + 2, args.type_codes + 2, args.size() - 2);
+          if (worker_id == -1) {
+            local_session_->BroadcastPacked(args);
+          } else {
+            local_session_->SendPacked(local_worker_id, args);
+          }
+          break;
+        }
+        case DiscoSocketAction::kReceive: {
+          args = local_session_->RecvReplyPacked(local_worker_id);
+          channel_->Reply(args);
+          break;
+        }
+        case DiscoSocketAction::kShutdown: {
+          local_session_->Shutdown();
+          LOG(INFO) << "Connection closed by remote controller.";
+          return;
+        }
+        default:
+          LOG(FATAL) << "Invalid action " << static_cast<int>(action);
+      }
+    }
+  }
+
+  ~RemoteSocketSession() {
+    socket_.Close();
+    Socket::Finalize();
+  }
+
+ private:
+  void InitLocalSession() {
+    const PackedFunc* f_create_local_session =
+        Registry::Get("runtime.disco.create_socket_session_local_workers");
+    local_session_ =
+        ((*f_create_local_session)(num_workers_per_node_)).AsObjectRef<BcastSession>();
+
+    DRef f_init_workers =
+        local_session_->GetGlobalFunc("runtime.disco.socket_session_init_workers");
+    local_session_->CallPacked(f_init_workers, num_nodes_, node_id_, num_groups_,
+                               num_workers_per_node_);
+  }
+
+  TCPSocket socket_;
+  BcastSession local_session_{nullptr};
+  std::unique_ptr<DiscoSocketChannel> channel_;
+  int num_nodes_{-1};
+  int node_id_{-1};
+  int num_groups_{-1};
+  int num_workers_per_node_{-1};
+};
+
+void RemoteSocketSessionEntryPoint(const String& server_host, int server_port,
+                                   int num_local_workers) {
+  RemoteSocketSession proxy(server_host, server_port, num_local_workers);
+  proxy.MainLoop();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession")
+    .set_body_typed(RemoteSocketSessionEntryPoint);
+
+Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host,
+                      int port) {
+  auto n = make_object<SocketSessionObj>(num_nodes, num_workers_per_node, num_groups, host, port);
+  return Session(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession);
+
+TVM_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers")
+    .set_body_typed([](int num_nodes, int node_id, int num_groups, int num_workers_per_node) {
+      LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, "
+                << num_workers_per_node << " workers per node, and " << num_groups << " groups.";
+      DiscoWorker* worker = DiscoWorker::ThreadLocal();
+      worker->num_groups = num_groups;
+      worker->worker_id = worker->worker_id + node_id * num_workers_per_node;
+      worker->num_workers = num_nodes * num_workers_per_node;
+    });
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/disco/message_queue.h b/src/runtime/disco/message_queue.h
new file mode 100644
index 000000000000..3b78c3e5c187
--- /dev/null
+++ b/src/runtime/disco/message_queue.h
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
+#define TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
+
+#include
+
+#include
+
+#include "./protocol.h"
+
+namespace tvm {
+namespace runtime {
+
+class DiscoStreamMessageQueue : private dmlc::Stream,
+                                private DiscoProtocol<DiscoStreamMessageQueue> {
+ public:
+  explicit DiscoStreamMessageQueue(Stream* stream) : stream_(stream) {}
+
+  ~DiscoStreamMessageQueue() = default;
+
+  void Send(const TVMArgs& args) {
+    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this);
+    CommitSendAndNotifyEnqueue();
+  }
+
+  TVMArgs Recv() {
+    bool is_implicit_shutdown = DequeueNextPacket();
+    TVMValue* values = nullptr;
+    int* type_codes = nullptr;
+    int num_args = 0;
+
+    if (is_implicit_shutdown) {
+      num_args = 2;
+      values = ArenaAlloc<TVMValue>(num_args);
+      type_codes = ArenaAlloc<int>(num_args);
+      TVMArgsSetter setter(values, type_codes);
+      setter(0, static_cast<int>(DiscoAction::kShutDown));
+      setter(1, 0);
+    } else {
+      RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
+    }
+    return TVMArgs(values, type_codes, num_args);
+  }
+
+ protected:
+  void CommitSendAndNotifyEnqueue() {
+    stream_->Write(write_buffer_.data(), write_buffer_.size());
+    write_buffer_.clear();
+  }
+
+  /* \brief Read next packet and reset unpacker
+   *
+   * Read the next packet into `read_buffer_`, releasing all arena
+   * allocations performed by the unpacker and resetting the unpacker
+   * to its initial state.
+   *
+   * \return A boolean value. If true, this packet should be treated
+   *     equivalently to a `DiscoAction::kShutdown` event. If false,
+   *     this packet should be unpacked.
+   */
+  bool DequeueNextPacket() {
+    uint64_t packet_nbytes = 0;
+    int read_size = stream_->Read(&packet_nbytes, sizeof(packet_nbytes));
+    if (read_size == 0) {
+      // Special case, connection dropped between packets. Treat as a
+      // request to shutdown.
+      return true;
+    }
+
+    ICHECK_EQ(read_size, sizeof(packet_nbytes))
+        << "Stream closed without proper shutdown. Please make sure to explicitly call "
+           "`Session::Shutdown`";
+    read_buffer_.resize(packet_nbytes);
+    read_size = stream_->Read(read_buffer_.data(), packet_nbytes);
+    ICHECK_EQ(read_size, packet_nbytes)
+        << "Stream closed without proper shutdown. Please make sure to explicitly call "
+           "`Session::Shutdown`";
+    read_offset_ = 0;
+    this->RecycleAll();
+    RPCCode code = RPCCode::kReturn;
+    this->Read(&code);
+    return false;
+  }
+
+  size_t Read(void* data, size_t size) final {
+    std::memcpy(data, read_buffer_.data() + read_offset_, size);
+    read_offset_ += size;
+    ICHECK_LE(read_offset_, read_buffer_.size());
+    return size;
+  }
+
+  size_t Write(const void* data, size_t size) final {
+    size_t cur_size = write_buffer_.size();
+    write_buffer_.resize(cur_size + size);
+    std::memcpy(write_buffer_.data() + cur_size, data, size);
+    return size;
+  }
+
+  using dmlc::Stream::Read;
+  using dmlc::Stream::ReadArray;
+  using dmlc::Stream::Write;
+  using dmlc::Stream::WriteArray;
+  friend struct RPCReference;
+  friend struct DiscoProtocol<DiscoStreamMessageQueue>;
+
+  // The read/write buffer will only be accessed by the producer thread.
+  std::string write_buffer_;
+  std::string read_buffer_;
+  size_t read_offset_ = 0;
+  dmlc::Stream* stream_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 2d2c528b5291..33baca8e369b 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -86,7 +86,8 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
       << "and has not been destructed";
 
   // Step up local context of NCCL
-  int device_id = device_ids[worker->worker_id];
+  int group_size = worker->num_workers / worker->num_groups;
+  int device_id = device_ids[worker->local_worker_id];
   SetDevice(device_id);
 #if TVM_NCCL_RCCL_SWITCH == 0
   StreamCreate(&ctx->default_stream);
@@ -99,7 +100,6 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
   // Initialize the communicator
   ncclUniqueId id;
   std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
-  int group_size = worker->num_workers / worker->num_groups;
   NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, worker->worker_id));
   NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size,
                           worker->worker_id % group_size, &ctx->group_comm, NULL));
diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc
index 7c8d0796dd81..161c3f6e0408 100644
--- a/src/runtime/disco/process_session.cc
+++ b/src/runtime/disco/process_session.cc
@@ -31,114 +31,19 @@
 #include "../minrpc/rpc_reference.h"
 #include "./bcast_session.h"
 #include "./disco_worker_thread.h"
+#include "./message_queue.h"
 #include "./protocol.h"
 
 namespace tvm {
 namespace runtime {
 
-class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocol<DiscoPipeMessageQueue> {
- public:
-  explicit DiscoPipeMessageQueue(int64_t handle) : pipe_(handle) {}
-
-  ~DiscoPipeMessageQueue() = default;
-
-  void Send(const TVMArgs& args) {
-    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this);
-    CommitSendAndNotifyEnqueue();
-  }
-
-  TVMArgs Recv() {
-    bool is_implicit_shutdown = DequeueNextPacket();
-    TVMValue* values = nullptr;
-    int* type_codes = nullptr;
-    int num_args = 0;
-
-    if (is_implicit_shutdown) {
-      num_args = 2;
-      values = ArenaAlloc<TVMValue>(num_args);
-      type_codes = ArenaAlloc<int>(num_args);
-      TVMArgsSetter setter(values, type_codes);
-      setter(0, static_cast<int>(DiscoAction::kShutDown));
-      setter(1, 0);
-    } else {
-      RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
-    }
-    return TVMArgs(values, type_codes, num_args);
-  }
-
- protected:
-  void CommitSendAndNotifyEnqueue() {
-    pipe_.Write(write_buffer_.data(), write_buffer_.size());
-    write_buffer_.clear();
-  }
-
-  /* \brief Read next packet and reset unpacker
-   *
-   * Read the next packet into `read_buffer_`, releasing all arena
-   * allocations performed by the unpacker and resetting the unpacker
-   * to its initial state.
-   *
-   * \return A boolean value. If true, this packet should be treated
-   *     equivalently to a `DiscoAction::kShutdown` event. If false,
-   *     this packet should be unpacked.
-   */
-  bool DequeueNextPacket() {
-    uint64_t packet_nbytes = 0;
-    int read_size = pipe_.Read(&packet_nbytes, sizeof(packet_nbytes));
-    if (read_size == 0) {
-      // Special case, connection dropped between packets. Treat as a
-      // request to shutdown.
-      return true;
-    }
-
-    ICHECK_EQ(read_size, sizeof(packet_nbytes))
-        << "Pipe closed without proper shutdown. Please make sure to explicitly call "
-           "`Session::Shutdown`";
-    read_buffer_.resize(packet_nbytes);
-    read_size = pipe_.Read(read_buffer_.data(), packet_nbytes);
-    ICHECK_EQ(read_size, packet_nbytes)
-        << "Pipe closed without proper shutdown. Please make sure to explicitly call "
-           "`Session::Shutdown`";
-    read_offset_ = 0;
-    this->RecycleAll();
-    RPCCode code = RPCCode::kReturn;
-    this->Read(&code);
-    return false;
-  }
-
-  size_t Read(void* data, size_t size) final {
-    std::memcpy(data, read_buffer_.data() + read_offset_, size);
-    read_offset_ += size;
-    ICHECK_LE(read_offset_, read_buffer_.size());
-    return size;
-  }
-
-  size_t Write(const void* data, size_t size) final {
-    size_t cur_size = write_buffer_.size();
-    write_buffer_.resize(cur_size + size);
-    std::memcpy(write_buffer_.data() + cur_size, data, size);
-    return size;
-  }
-
-  using dmlc::Stream::Read;
-  using dmlc::Stream::ReadArray;
-  using dmlc::Stream::Write;
-  using dmlc::Stream::WriteArray;
-  friend struct RPCReference;
-  friend struct DiscoProtocol<DiscoPipeMessageQueue>;
-
-  // The read/write buffer will only be accessed by the producer thread.
-  std::string write_buffer_;
-  std::string read_buffer_;
-  size_t read_offset_ = 0;
-  support::Pipe pipe_;
-};
-
 class DiscoProcessChannel final : public DiscoChannel {
  public:
   DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t worker_to_controler_fd)
-      : controler_to_worker_(controler_to_worker_fd),
-        worker_to_controler_(worker_to_controler_fd) {}
+      : controller_to_worker_pipe_(controler_to_worker_fd),
+        worker_to_controller_pipe_(worker_to_controler_fd),
+        controler_to_worker_(&controller_to_worker_pipe_),
+        worker_to_controler_(&worker_to_controller_pipe_) {}
 
   DiscoProcessChannel(DiscoProcessChannel&& other) = delete;
   DiscoProcessChannel(const DiscoProcessChannel& other) = delete;
@@ -148,8 +53,10 @@ class DiscoProcessChannel final : public DiscoChannel {
   void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); }
   TVMArgs RecvReply() { return worker_to_controler_.Recv(); }
 
-  DiscoPipeMessageQueue controler_to_worker_;
-  DiscoPipeMessageQueue worker_to_controler_;
+  support::Pipe controller_to_worker_pipe_;
+  support::Pipe worker_to_controller_pipe_;
+  DiscoStreamMessageQueue controler_to_worker_;
+  DiscoStreamMessageQueue worker_to_controler_;
 };
 
 class ProcessSessionObj final : public BcastSessionObj {
@@ -226,7 +133,7 @@ class ProcessSessionObj final : public BcastSessionObj {
       int type_codes[4];
       PackArgs(values, type_codes, static_cast<int>(DiscoAction::kDebugSetRegister), reg_id,
               worker_id, value);
-      workers_[worker_id - 1]->Send(TVMArgs(values, type_codes, 4));
+      SendPacked(worker_id, TVMArgs(values, type_codes, 4));
     }
     TVMRetValue result;
     TVMArgs args = this->RecvReplyPacked(worker_id);
@@ -241,6 +148,14 @@ class ProcessSessionObj final : public BcastSessionObj {
     }
   }
 
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    if (worker_id == 0) {
+      worker_0_->channel->Send(args);
+    } else {
+      workers_.at(worker_id - 1)->Send(args);
+    }
+  }
+
   TVMArgs RecvReplyPacked(int worker_id) final {
     if (worker_id == 0) {
       return worker_0_->channel->RecvReply();
@@ -248,6 +163,13 @@ class ProcessSessionObj final : public BcastSessionObj {
     return this->workers_.at(worker_id - 1)->RecvReply();
   }
 
+  DiscoChannel* GetWorkerChannel(int worker_id) {
+    if (worker_id == 0) {
+      return worker_0_->channel.get();
+    }
+    return workers_.at(worker_id - 1).get();
+  }
+
   PackedFunc process_pool_;
   std::unique_ptr<DiscoWorkerThread> worker_0_;
   std::vector<std::unique_ptr<DiscoProcessChannel>> workers_;
diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc
index cc9a311a6b3f..bf6b6107e122 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -173,6 +173,10 @@ class ThreadedSessionObj final : public BcastSessionObj {
     }
   }
 
+  void SendPacked(int worker_id, const TVMArgs& args) final {
+    this->workers_.at(worker_id).channel->Send(args);
+  }
+
   TVMArgs RecvReplyPacked(int worker_id) final {
     return this->workers_.at(worker_id).channel->RecvReply();
   }
diff --git a/src/support/socket.h b/src/support/socket.h
index ac13cd3f2d35..032cf257c045 100644
--- a/src/support/socket.h
+++ b/src/support/socket.h
@@ -370,7 +370,7 @@ class Socket {
 /*!
  * \brief a wrapper of TCP socket that hopefully be cross platform
 */
-class TCPSocket : public Socket {
+class TCPSocket : public Socket, public dmlc::Stream {
  public:
   TCPSocket() : Socket(INVALID_SOCKET) {}
   /*!
@@ -552,6 +552,10 @@ class TCPSocket : public Socket {
     ICHECK_EQ(RecvAll(&data[0], datalen), datalen);
     return data;
   }
+
+  size_t Read(void* data, size_t size) final { return Recv(data, size); }
+
+  size_t Write(const void* data, size_t size) final { return Send(data, size); }
 };
 
 /*! \brief helper data structure to perform poll */
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
index 837b3a14f271..38aa757bf8f1 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -20,6 +20,9 @@
 
 import numpy as np
 import pytest
+import subprocess
+import threading
+import sys
 
 import tvm
 import tvm.testing
@@ -29,7 +32,7 @@
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
-from tvm.exec import disco_worker as _
+from tvm.exec import disco_worker as _  # pylint: disable=unused-import
 
 
 def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
@@ -46,7 +49,75 @@ def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
     return host_array.numpy()
 
 
-_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
+_SOCKET_SESSION_TESTER = None
+
+
+def get_free_port():
+    import socket
+
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    port = s.getsockname()[1]
+    s.close()
+    return port
+
+
+class SocketSessionTester:
+    def __init__(self, num_workers):
+        num_nodes = 2
+        num_groups = 1
+        assert num_workers % num_nodes == 0
+        num_workers_per_node = num_workers // num_nodes
+        server_host = "localhost"
+        server_port = get_free_port()
+        self.sess = None
+
+        def start_server():
+            self.sess = di.SocketSession(
+                num_nodes, num_workers_per_node, num_groups, server_host, server_port
+            )
+
+        thread = threading.Thread(target=start_server)
+        thread.start()
+
+        cmd = "tvm.exec.disco_remote_socket_session"
+        self.remote_nodes = []
+        for _ in range(num_nodes - 1):
+            self.remote_nodes.append(
+                subprocess.Popen(
+                    [
+                        "python3",
+                        "-m",
+                        cmd,
+                        server_host,
+                        str(server_port),
+                        str(num_workers_per_node),
+                    ],
+                    stdout=sys.stdout,
+                    stderr=sys.stderr,
+                )
+            )
+
+        thread.join()
+
+    def __del__(self):
+        for node in self.remote_nodes:
+            node.kill()
+        if self.sess is not None:
+            self.sess.shutdown()
+            del self.sess
+
+
+def create_socket_session(num_workers):
+    global _SOCKET_SESSION_TESTER
+    if _SOCKET_SESSION_TESTER is not None:
+        del _SOCKET_SESSION_TESTER
+    _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
+    assert _SOCKET_SESSION_TESTER.sess is not None
+    return _SOCKET_SESSION_TESTER.sess
+
+
+_all_session_kinds = [di.ThreadedSession, di.ProcessSession, create_socket_session]
 
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
@@ -157,6 +228,11 @@ def main(A: R.Tensor((8, 16), dtype="float32")) -> R.Tensor((16, 8), dtype="floa
         y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape, dtype=y_np.dtype)
         np.testing.assert_equal(y_nd, y_np)
 
+    # sync all workers to make sure the temporary files are cleaned up after all workers
+    # finish the execution
+    for i in range(num_workers):
+        sess._sync_worker(i)
+
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 def test_vm_multi_func(session_kind):
@@ -220,10 +296,17 @@ def transpose_2(
     np.testing.assert_equal(y_nd, y_np)
     np.testing.assert_equal(z_nd, x_np)
 
+    # sync all workers to make sure the temporary files are cleaned up after all workers
+    # finish the execution
+    for i in range(num_workers):
+        sess._sync_worker(i)
+
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("num_workers", [1, 2, 4]) def test_num_workers(session_kind, num_workers): + if session_kind == create_socket_session and num_workers < 2: + return sess = session_kind(num_workers=num_workers) assert sess.num_workers == num_workers