Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions csrc/multidevice/device_mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class DeviceMesh final {

std::string toString() const;

// returns the number of devices in the mesh
int64_t size() const {
  const auto num_devices = vector_.size();
  return static_cast<int64_t>(num_devices);
}

// returns a vector containing the device indices of the mesh
const auto& vector() const {
return vector_;
Expand All @@ -44,6 +49,21 @@ class DeviceMesh final {
return std::find(vector_.begin(), vector_.end(), device) != vector_.end();
}

// returns the position of `device` within the mesh,
// or -1 when the device does not belong to it
int64_t idxOf(const DeviceIdxType device) const {
  const auto it = std::find(vector_.begin(), vector_.end(), device);
  return (it == vector_.end()) ? -1 : std::distance(vector_.begin(), it);
}

// Returns the device at a particular index in the mesh
// NOTE: bounds checking is delegated to std::vector::at, which throws
// std::out_of_range when index is outside [0, size()).
DeviceIdxType at(int64_t index) const {
  return vector_.at(index);
}

// two meshes compare equal iff they hold the same device indices
// in the same order
bool operator==(const DeviceMesh& other) const {
  return other.vector() == vector_;
}
Expand Down
21 changes: 10 additions & 11 deletions csrc/multidevice/lower_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void lowerToScatter(
std::vector<Communication*>& comms) {
// we arbitrarily choose the first device of the sender mesh to be the root
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
auto root = input_tv->getDeviceMesh().vector().at(0);
auto root = input_tv->getDeviceMesh().at(0);
if (!isDeviceInvolved(my_device_index, root, receiver_mesh)) {
return;
}
Expand Down Expand Up @@ -181,20 +181,20 @@ void lowerToBroadcastOrP2P(
if (is_sharded) {
// if the inputs and outputs are parallelized,
// we create as many Broadcasts as will be handled in parallel
for (auto i : c10::irange(sender_mesh.vector().size())) {
for (auto i : c10::irange(sender_mesh.size())) {
NVF_ERROR(
sender_mesh.vector().size() == receiver_mesh.vector().size(),
sender_mesh.size() == receiver_mesh.size(),
"the receiver and sender meshes have different sizes");
lowerToBroadcastOrP2P(
my_device_index,
sender_mesh.vector().at(i),
DeviceMesh({receiver_mesh.vector().at(i)}),
sender_mesh.at(i),
DeviceMesh({receiver_mesh.at(i)}),
comms);
}
} else {
// we arbitrarily choose the first device of the sender mesh to be the root
lowerToBroadcastOrP2P(
my_device_index, sender_mesh.vector().at(0), receiver_mesh, comms);
my_device_index, sender_mesh.at(0), receiver_mesh, comms);
}
}

Expand Down Expand Up @@ -309,13 +309,12 @@ std::vector<Communication*> lowerCommunication(

const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
const bool same_mesh = sender_mesh.vector() == receiver_mesh.vector();
const bool same_mesh = sender_mesh == receiver_mesh;

// Stores whether the I/O has its first axis parallelized on Didx
const bool is_input_sharded =
isSharded(input_tv) && sender_mesh.vector().size() > 1;
const bool is_input_sharded = isSharded(input_tv) && sender_mesh.size() > 1;
const bool is_output_sharded =
isSharded(output_tv) && receiver_mesh.vector().size() > 1;
isSharded(output_tv) && receiver_mesh.size() > 1;

auto original_expr = output_tv->definition();
NVF_ERROR(
Expand All @@ -333,7 +332,7 @@ std::vector<Communication*> lowerCommunication(
BinaryOpType op_type =
output_tv->definition()->as<ReductionOp>()->getReductionOpType();
NVF_ERROR(
is_input_sharded || sender_mesh.vector().size() == 1,
is_input_sharded || sender_mesh.size() == 1,
"the comm input must be sharded in case of reduce.",
"Insert a `set` before the reduction to reshard")
if (is_output_sharded) {
Expand Down
9 changes: 3 additions & 6 deletions tests/cpp/multidevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,9 @@ void MultiDeviceTest::SetUp() {
return tensor;
}
auto sharded_dim = getShardedAxis(tv);
int i = 0;
const auto& devices = tv->getDeviceMesh().vector();
auto it = std::find(devices.begin(), devices.end(), deviceId);
if (it != devices.end()) {
i = std::distance(devices.begin(), it);
}
auto i = tv->getDeviceMesh().idxOf(deviceId);
// TODO: returning slice 0 temporarily when device is not in the mesh.
i = (i < 0) ? 0 : i;
return tensor.slice(sharded_dim, i, i + 1).contiguous();
}

Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/test_multidevice_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ TEST_P(PipelineTestTwoStages, Communication) {

std::vector<int64_t> unsharded_input_sizes = {3, 2, 3, 5};
if (is_stage0_sharded) {
unsharded_input_sizes[sharded_dim] = mesh0.vector().size();
unsharded_input_sizes[sharded_dim] = mesh0.size();
}
if (is_stage1_sharded) {
unsharded_input_sizes[sharded_dim] = mesh1.vector().size();
unsharded_input_sizes[sharded_dim] = mesh1.size();
if (do_reduction) {
ASSERT_EQ(mesh0.vector().size(), mesh1.vector().size());
unsharded_input_sizes[sharded_dim + 1] = mesh1.vector().size();
ASSERT_EQ(mesh0.size(), mesh1.size());
unsharded_input_sizes[sharded_dim + 1] = mesh1.size();
}
}

Expand Down