diff --git a/csrc/multidevice/device_mesh.h b/csrc/multidevice/device_mesh.h
index 19625de1197..1305926d23d 100644
--- a/csrc/multidevice/device_mesh.h
+++ b/csrc/multidevice/device_mesh.h
@@ -34,6 +34,11 @@ class DeviceMesh final {
 
   std::string toString() const;
 
+  // returns the number of devices in the mesh
+  int64_t size() const {
+    return static_cast<int64_t>(vector_.size());
+  }
+
   // returns a vector containing the device indices of the mesh
   const auto& vector() const {
     return vector_;
@@ -44,6 +49,21 @@ class DeviceMesh final {
     return std::find(vector_.begin(), vector_.end(), device) != vector_.end();
   }
 
+  // returns the index of device in the mesh.
+  // returns -1 if device is not present.
+  int64_t idxOf(const DeviceIdxType device) const {
+    auto it = std::find(vector_.begin(), vector_.end(), device);
+    if (it != vector_.end()) {
+      return std::distance(vector_.begin(), it);
+    }
+    return -1;
+  }
+
+  // Returns the device at a particular index in the mesh
+  DeviceIdxType at(int64_t index) const {
+    return vector_.at(index);
+  }
+
   bool operator==(const DeviceMesh& other) const {
     return vector_ == other.vector();
   }
diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp
index 6de4bd78376..726569aea15 100644
--- a/csrc/multidevice/lower_communication.cpp
+++ b/csrc/multidevice/lower_communication.cpp
@@ -86,7 +86,7 @@ void lowerToScatter(
     std::vector<std::shared_ptr<Communication>>& comms) {
   // we arbitrarily choose the first device of the sender mesh to be the root
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
-  auto root = input_tv->getDeviceMesh().vector().at(0);
+  auto root = input_tv->getDeviceMesh().at(0);
   if (!isDeviceInvolved(my_device_index, root, receiver_mesh)) {
     return;
   }
@@ -181,20 +181,20 @@ void lowerToBroadcastOrP2P(
   if (is_sharded) {
     // if the inputs and ouputs are parallelized,
     // we create as many Broadcast as that will be handled in parallel
-    for (auto i : c10::irange(sender_mesh.vector().size())) {
+    for (auto i : c10::irange(sender_mesh.size())) {
       NVF_ERROR(
-          sender_mesh.vector().size() == receiver_mesh.vector().size(),
+          sender_mesh.size() == receiver_mesh.size(),
           "the receiver and sender meshes have different sizes");
       lowerToBroadcastOrP2P(
           my_device_index,
-          sender_mesh.vector().at(i),
-          DeviceMesh({receiver_mesh.vector().at(i)}),
+          sender_mesh.at(i),
+          DeviceMesh({receiver_mesh.at(i)}),
           comms);
     }
   } else {
     // we arbitrarily choose the first device of the sender mesh to be the root
     lowerToBroadcastOrP2P(
-        my_device_index, sender_mesh.vector().at(0), receiver_mesh, comms);
+        my_device_index, sender_mesh.at(0), receiver_mesh, comms);
   }
 }
 
@@ -309,13 +309,12 @@ std::vector<std::shared_ptr<Communication>> lowerCommunication(
 
   const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
-  const bool same_mesh = sender_mesh.vector() == receiver_mesh.vector();
+  const bool same_mesh = sender_mesh == receiver_mesh;
   // Stores whether the I/O has its first axis parallelized on Didx
-  const bool is_input_sharded =
-      isSharded(input_tv) && sender_mesh.vector().size() > 1;
+  const bool is_input_sharded = isSharded(input_tv) && sender_mesh.size() > 1;
   const bool is_output_sharded =
-      isSharded(output_tv) && receiver_mesh.vector().size() > 1;
+      isSharded(output_tv) && receiver_mesh.size() > 1;
 
   auto original_expr = output_tv->definition();
   NVF_ERROR(
@@ -333,7 +332,7 @@ std::vector<std::shared_ptr<Communication>> lowerCommunication(
     BinaryOpType op_type =
         output_tv->definition()->as<ReductionOp>()->getReductionOpType();
     NVF_ERROR(
-        is_input_sharded || sender_mesh.vector().size() == 1,
+        is_input_sharded || sender_mesh.size() == 1,
         "the comm input must be sharded in case of reduce.",
         "Insert a `set` before the reduction to reshard")
     if (is_output_sharded) {
diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp
index 31b4cc8add5..543394b7306 100644
--- a/tests/cpp/multidevice.cpp
+++ b/tests/cpp/multidevice.cpp
@@ -97,12 +97,9 @@ void MultiDeviceTest::SetUp() {
       return tensor;
     }
     auto sharded_dim = getShardedAxis(tv);
-    int i = 0;
-    const auto& devices = tv->getDeviceMesh().vector();
-    auto it = std::find(devices.begin(), devices.end(), deviceId);
-    if (it != devices.end()) {
-      i = std::distance(devices.begin(), it);
-    }
+    auto i = tv->getDeviceMesh().idxOf(deviceId);
+    // TODO: returning slice 0 temporarily when device is not in the mesh.
+    i = (i < 0) ? 0 : i;
     return tensor.slice(sharded_dim, i, i + 1).contiguous();
   }
 
diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp
index fa45045aee2..6c114448744 100644
--- a/tests/cpp/test_multidevice_pipeline.cpp
+++ b/tests/cpp/test_multidevice_pipeline.cpp
@@ -163,13 +163,13 @@ TEST_P(PipelineTestTwoStages, Communication) {
   std::vector<int64_t> unsharded_input_sizes = {3, 2, 3, 5};
   if (is_stage0_sharded) {
-    unsharded_input_sizes[sharded_dim] = mesh0.vector().size();
+    unsharded_input_sizes[sharded_dim] = mesh0.size();
   }
   if (is_stage1_sharded) {
-    unsharded_input_sizes[sharded_dim] = mesh1.vector().size();
+    unsharded_input_sizes[sharded_dim] = mesh1.size();
     if (do_reduction) {
-      ASSERT_EQ(mesh0.vector().size(), mesh1.vector().size());
-      unsharded_input_sizes[sharded_dim + 1] = mesh1.vector().size();
+      ASSERT_EQ(mesh0.size(), mesh1.size());
+      unsharded_input_sizes[sharded_dim + 1] = mesh1.size();
     }
   }