From b53eaece1f6a866e0918e9ddaa291db49fdb2af2 Mon Sep 17 00:00:00 2001 From: mcowan Date: Fri, 3 May 2024 22:07:14 +0000 Subject: [PATCH 1/3] hide vector in DeviceMesh --- csrc/multidevice/device_mesh.h | 16 ++++++++++++++++ tests/cpp/multidevice.cpp | 9 +++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/csrc/multidevice/device_mesh.h b/csrc/multidevice/device_mesh.h index 19625de1197..39e411e3d6d 100644 --- a/csrc/multidevice/device_mesh.h +++ b/csrc/multidevice/device_mesh.h @@ -34,6 +34,11 @@ class DeviceMesh final { std::string toString() const; + // returns the number of devices in the mesh + int64_t size() { + return vector_.size(); + } + // returns a vector containing the device indices of the mesh const auto& vector() const { return vector_; @@ -44,6 +49,17 @@ class DeviceMesh final { return std::find(vector_.begin(), vector_.end(), device) != vector_.end(); } + + // returns the index of device in the mesh. + // returns -1 if device is not present. + int64_t idxOf(const DeviceIdxType device) { + auto it = std::find(vector_.begin(), vector_.end(), device); + if (it != vector_.end()) { + return std::distance(vector_.begin(), it); + } + return -1; + } + bool operator==(const DeviceMesh& other) const { return vector_ == other.vector(); } diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index d22e1c05800..7b013ac83b4 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -50,12 +50,9 @@ void MultiDeviceTest::SetUp() { return tensor; } auto sharded_dim = getShardedAxis(tv); - int i = 0; - const auto& devices = tv->getDeviceMesh().vector(); - auto it = std::find(devices.begin(), devices.end(), deviceId); - if (it != devices.end()) { - i = std::distance(devices.begin(), it); - } + int i = tv->getDeviceMesh().idxOf(deviceId); + // TODO: returning slice 0 temporarily when device is not in the mesh. + i = (i < 0) : 0 ? i; return tensor.slice(sharded_dim, i, i + 1).contiguous(); } From 90948ccd5143eb1a012d7c0307d9018111db6dbe Mon Sep 17 00:00:00 2001 From: mcowan Date: Fri, 3 May 2024 23:06:29 +0000 Subject: [PATCH 2/3] fix errors --- csrc/multidevice/device_mesh.h | 3 +-- tests/cpp/multidevice.cpp | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/multidevice/device_mesh.h b/csrc/multidevice/device_mesh.h index 39e411e3d6d..d2d547c7bcf 100644 --- a/csrc/multidevice/device_mesh.h +++ b/csrc/multidevice/device_mesh.h @@ -49,10 +49,9 @@ class DeviceMesh final { return std::find(vector_.begin(), vector_.end(), device) != vector_.end(); } - // returns the index of device in the mesh. // returns -1 if device is not present. - int64_t idxOf(const DeviceIdxType device) { + int64_t idxOf(const DeviceIdxType device) const { auto it = std::find(vector_.begin(), vector_.end(), device); if (it != vector_.end()) { return std::distance(vector_.begin(), it); diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index 7b013ac83b4..d073bddea0e 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -52,7 +52,7 @@ void MultiDeviceTest::SetUp() { auto sharded_dim = getShardedAxis(tv); int i = tv->getDeviceMesh().idxOf(deviceId); // TODO: returning slice 0 temporarily when device is not in the mesh. - i = (i < 0) : 0 ? i; + i = (i < 0) ? 0 : i; return tensor.slice(sharded_dim, i, i + 1).contiguous(); } From 8d8db4b345a0efa47ba6114ce10894dbe7ea43bf Mon Sep 17 00:00:00 2001 From: mcowan Date: Thu, 9 May 2024 20:52:45 +0000 Subject: [PATCH 3/3] use helper functions --- csrc/multidevice/device_mesh.h | 7 ++++++- csrc/multidevice/lower_communication.cpp | 23 +++++++++++------------ tests/cpp/multidevice.cpp | 2 +- tests/cpp/test_multidevice_pipeline.cpp | 8 ++++---- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/csrc/multidevice/device_mesh.h b/csrc/multidevice/device_mesh.h index d2d547c7bcf..a620742f6ff 100644 --- a/csrc/multidevice/device_mesh.h +++ b/csrc/multidevice/device_mesh.h @@ -35,7 +35,7 @@ class DeviceMesh final { std::string toString() const; // returns the number of devices in the mesh - int64_t size() { + int64_t size() const { return vector_.size(); } @@ -59,6 +59,11 @@ class DeviceMesh final { return -1; } + // Returns the device at a particular index in the mesh + DeviceIdxType at(int64_t index) const { + return vector_.at(index); + } + bool operator==(const DeviceMesh& other) const { return vector_ == other.vector(); } diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 7fb6d4c7f2d..33500696014 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -83,7 +83,7 @@ void lowerToScatter( std::vector>& comms) { // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - auto root = input_tv->getDeviceMesh().vector().at(0); + auto root = input_tv->getDeviceMesh().at(0); if (!isDeviceInvolved(my_device_index, root, receiver_mesh)) { return; } @@ -159,7 +159,7 @@ void lowerToBroadcastOrP2P( } auto params = createParamsForBroadcastOrP2P(my_device_index, root, mesh); std::shared_ptr comm; - if (mesh.vector().size() == 1) { + if (mesh.size() == 1) { comm = std::make_shared(std::move(params)); } else { comm = std::make_shared(std::move(params)); @@ -182,20 +182,20 @@ void lowerToBroadcastOrP2P( if (is_sharded) { // if the inputs and ouputs are parallelized, // we create as many Broadcast as that will be handled in parallel - for (auto i : c10::irange(sender_mesh.vector().size())) { + for (auto i : c10::irange(sender_mesh.size())) { NVF_ERROR( - sender_mesh.vector().size() == receiver_mesh.vector().size(), + sender_mesh.size() == receiver_mesh.size(), "the receiver and sender meshes have different sizes"); lowerToBroadcastOrP2P( my_device_index, - sender_mesh.vector().at(i), - DeviceMesh({receiver_mesh.vector().at(i)}), + sender_mesh.at(i), + DeviceMesh({receiver_mesh.at(i)}), comms); } } else { // we arbitrarily choose the first device of the sender mesh to be the root lowerToBroadcastOrP2P( - my_device_index, sender_mesh.vector().at(0), receiver_mesh, comms); + my_device_index, sender_mesh.at(0), receiver_mesh, comms); } } @@ -307,13 +307,12 @@ std::vector> lowerCommunication( const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - const bool same_mesh = sender_mesh.vector() == receiver_mesh.vector(); + const bool same_mesh = sender_mesh == receiver_mesh; // Stores whether the I/O has its first axis parallelized on Didx - const bool is_input_sharded = - isSharded(input_tv) && sender_mesh.vector().size() > 1; + const bool is_input_sharded = isSharded(input_tv) && sender_mesh.size() > 1; const bool is_output_sharded = - isSharded(output_tv) && receiver_mesh.vector().size() > 1; + isSharded(output_tv) && receiver_mesh.size() > 1; auto original_expr = output_tv->definition(); NVF_ERROR( @@ -331,7 +330,7 @@ std::vector> lowerCommunication( BinaryOpType op_type = output_tv->definition()->as()->getReductionOpType(); NVF_ERROR( - is_input_sharded || sender_mesh.vector().size() == 1, + is_input_sharded || sender_mesh.size() == 1, "the comm input must be sharded in case of reduce.", "Insert a `set` before the reduction to reshard") if (is_output_sharded) { diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index 4a3adbb7b50..543394b7306 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -97,7 +97,7 @@ void MultiDeviceTest::SetUp() { return tensor; } auto sharded_dim = getShardedAxis(tv); - int i = tv->getDeviceMesh().idxOf(deviceId); + auto i = tv->getDeviceMesh().idxOf(deviceId); // TODO: returning slice 0 temporarily when device is not in the mesh. i = (i < 0) ? 0 : i; return tensor.slice(sharded_dim, i, i + 1).contiguous(); diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index fa45045aee2..6c114448744 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -163,13 +163,13 @@ TEST_P(PipelineTestTwoStages, Communication) { std::vector unsharded_input_sizes = {3, 2, 3, 5}; if (is_stage0_sharded) { - unsharded_input_sizes[sharded_dim] = mesh0.vector().size(); + unsharded_input_sizes[sharded_dim] = mesh0.size(); } if (is_stage1_sharded) { - unsharded_input_sizes[sharded_dim] = mesh1.vector().size(); + unsharded_input_sizes[sharded_dim] = mesh1.size(); if (do_reduction) { - ASSERT_EQ(mesh0.vector().size(), mesh1.vector().size()); - unsharded_input_sizes[sharded_dim + 1] = mesh1.vector().size(); + ASSERT_EQ(mesh0.size(), mesh1.size()); + unsharded_input_sizes[sharded_dim + 1] = mesh1.size(); } }