diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h
index 7d15e35fbdbc..4453d9737f89 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -114,6 +114,30 @@ TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv
  * \param buffer The buffer to be received
  */
 TVM_DLL void RecvFromWorker0(NDArray buffer);
+/*!
+ * \brief Send a buffer to the corresponding worker in the next group.
+ * An error is thrown if the worker is already in the last group.
+ * \param buffer The sending buffer.
+ */
+TVM_DLL void SendToNextGroup(NDArray buffer);
+/*!
+ * \brief Receive a buffer from the corresponding worker in the previous group.
+ * An error is thrown if the worker is already in the first group.
+ * \param buffer The receiving buffer.
+ */
+TVM_DLL void RecvFromPrevGroup(NDArray buffer);
+/*!
+ * \brief Send a buffer to the target receiver worker (globally across all groups).
+ * \param buffer The sending buffer.
+ * \param receiver_id The global receiver worker id.
+ */
+TVM_DLL void SendToWorker(NDArray buffer, int receiver_id);
+/*!
+ * \brief Receive a buffer from the target sender worker (globally across all groups).
+ * \param buffer The receiving buffer.
+ * \param sender_id The global sender worker id.
+ */
+TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id);
 /*! \brief Get the local worker id */
 TVM_DLL int WorkerId();
 /*!
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 46e016a242ea..3511c38a2b7c 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -549,16 +549,16 @@ def __init__(self, modules: List[Module]):
     def __iter__(self):
         return iter(self.modules)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Module:
         return self.modules[idx]
 
-    def __setitem__(self, idx, module):
+    def __setitem__(self, idx: int, module: Module) -> None:
         self.modules[idx] = module
 
     def __len__(self):
         return len(self.modules)
 
-    def append(self, module):
+    def append(self, module: Module):
         """Add a module to the end of the ModuleList"""
         self.modules.append(module)
 
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 0cb2ee6f5d6b..760a330a7a8e 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -101,6 +101,18 @@ void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {
 
 void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); }
 
+void SendToNextGroup(NDArray buffer) { GetCCLFunc("send_to_next_group")(buffer); }
+
+void RecvFromPrevGroup(NDArray buffer) { GetCCLFunc("recv_from_prev_group")(buffer); }
+
+void SendToWorker(NDArray buffer, int receiver_id) {
+  GetCCLFunc("send_to_worker")(buffer, receiver_id);
+}
+
+void RecvFromWorker(NDArray buffer, int sender_id) {
+  GetCCLFunc("recv_from_worker")(buffer, sender_id);
+}
+
 int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
 
 void SyncWorker() {
@@ -136,6 +148,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(Broad
 TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
+TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup);
+TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup);
+TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker);
+TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker);
 TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
   return ShapeTuple({WorkerId()});
 });
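For reference, the four builtins registered above resolve by name through TVM's global function registry. A minimal lookup sketch (not part of this patch; it is only meaningful inside a disco worker, since the builtins read DiscoWorker::ThreadLocal()):

    import tvm

    # Names follow the TVM_REGISTER_GLOBAL calls in builtin.cc above.
    send_to_next_group = tvm.get_global_func("runtime.disco.send_to_next_group")      # (buffer)
    recv_from_prev_group = tvm.get_global_func("runtime.disco.recv_from_prev_group")  # (buffer)
    send_to_worker = tvm.get_global_func("runtime.disco.send_to_worker")      # (buffer, receiver_id)
    recv_from_worker = tvm.get_global_func("runtime.disco.recv_from_worker")  # (buffer, sender_id)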
+TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup); +TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker); +TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker); TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple { return ShapeTuple({WorkerId()}); }); diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 2d2c528b5291..35e8fd06b309 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) { NCCL_CALL(ncclGroupEnd()); } +void SendToNextGroup(NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + int receiver_id = worker_id + group_size; + CHECK_LT(receiver_id, ctx->worker->num_workers) + << "The current group is already the last group and there is no such a next group."; + NCCL_CALL(ncclGroupStart()); + NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + receiver_id, ctx->global_comm, stream)); + NCCL_CALL(ncclGroupEnd()); +} + +void RecvFromPrevGroup(NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + int sender_id = worker_id - group_size; + CHECK_GE(sender_id, 0) + << "The current group is already the first group and there is no such a previous group."; + NCCL_CALL(ncclGroupStart()); + NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + sender_id, ctx->global_comm, stream)); + NCCL_CALL(ncclGroupEnd()); +} + +void SendToWorker(NDArray buffer, int receiver_id) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers) + << "Invalid receiver id " << receiver_id << ". The world size is " + << ctx->worker->num_workers; + CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself."; + NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + receiver_id, ctx->global_comm, stream)); +} + +void RecvFromWorker(NDArray buffer, int sender_id) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + deviceStream_t stream = ctx->GetDefaultStream(); + int worker_id = ctx->worker->worker_id; + CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers) + << "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers; + CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself."; + NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), + sender_id, ctx->global_comm, stream)); +} + void SyncWorker() { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); ICHECK(ctx->worker != nullptr); @@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0") .set_body_typed(GatherToWorker0); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0") .set_body_typed(RecvFromWorker0); +TVM_REGISTER_GLOBAL("runtime.disco." 
TVM_DISCO_CCL_NAME ".send_to_next_group") + .set_body_typed(SendToNextGroup); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group") + .set_body_typed(RecvFromPrevGroup); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker") + .set_body_typed(SendToWorker); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker") + .set_body_typed(RecvFromWorker); TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker); +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME + ".test_send_to_next_group_recv_from_prev_group") + .set_body_typed([](NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; + CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + int group_size = ctx->worker->num_workers / ctx->worker->num_groups; + int group_id = ctx->worker->worker_id / group_size; + if (group_id == 0) { + tvm::runtime::nccl::SendToNextGroup(buffer); + } else { + tvm::runtime::nccl::RecvFromPrevGroup(buffer); + } + }); + +TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0") + .set_body_typed([](NDArray buffer) { + CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); + CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; + CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + if (ctx->worker->worker_id == 2) { + tvm::runtime::nccl::SendToWorker(buffer, 0); + } else if (ctx->worker->worker_id == 0) { + tvm::runtime::nccl::RecvFromWorker(buffer, 2); + } + }); + } // namespace nccl } // namespace runtime } // namespace tvm diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index 6c63f64554a3..c29ece957245 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -25,11 +25,11 @@ import tvm import tvm.testing from tvm import dlight as dl +from tvm import get_global_func from tvm import relax as rx from tvm.runtime import disco as di from tvm.runtime.relax_vm import VirtualMachine from tvm.script import relax as R -from tvm import get_global_func _all_session_kinds = [di.ThreadedSession, di.ProcessSession] _ccl = [get_global_func("runtime.disco.compiled_ccl")()] @@ -391,6 +391,44 @@ def test_group_gather(session_kind, ccl, capfd): ), "No warning messages should be generated from disco.Session.gather_to_worker0" +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_send_to_next_group_receive_from_prev_group(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array_1 = np.arange(12, dtype="float32").reshape(3, 4) + array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4) + d_array = sess.empty((3, 4), "float32") + d_array.debug_copy_from(0, array_1) + d_array.debug_copy_from(1, array_2) + sess.get_global_func("runtime.disco." 
+ ccl + ".test_send_to_next_group_recv_from_prev_group")( + d_array + ) + + result_1 = d_array.debug_get_from_remote(2).numpy() + result_2 = d_array.debug_get_from_remote(3).numpy() + np.testing.assert_equal(result_1, array_1) + np.testing.assert_equal(result_2, array_2) + + +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_worker2_send_to_worker0(session_kind, ccl): + devices = [0, 1, 2, 3] + sess = session_kind(num_workers=len(devices), num_groups=2) + sess.init_ccl(ccl, *devices) + + array = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4) + d_array = sess.empty((3, 4), "float32") + d_array.debug_copy_from(2, array) + sess.get_global_func("runtime.disco." + ccl + ".test_worker2_sends_to_worker0")(d_array) + + result = d_array.debug_get_from_remote(0).numpy() + np.testing.assert_equal(result, array) + + @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals