Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,7 @@ c10::intrusive_ptr<c10d::Work> postScatter(
input_tensors.front().push_back(output_tensor);
continue;
}
input_tensors.front().push_back(
input_tensor.slice(0, j, j + 1).contiguous());
input_tensors.front().push_back(input_tensor.slice(0, j, j + 1));
j++;
}

Expand Down
3 changes: 3 additions & 0 deletions csrc/multidevice/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ MultiDeviceExecutor::MultiDeviceExecutor(
Communicator& comm,
MultiDeviceExecutorParams params)
: comm_(comm), params_(params) {
// Sharding passes.
// Note: these run before the PreSegmenter optimization passes.
propagateShardings(fusion.get());
insertReshardings(fusion.get());
insertShardedAxisReordering(fusion.get());
setShardedAllocationDomain(fusion.get());
SegmentCandidateFinderOptions options{
.run_translate_welford = false,
.run_combine_reductions = false,
Expand Down
46 changes: 38 additions & 8 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ void insertReshardingsAfter(Fusion* fusion) {
}
}

// Pin tv's allocation domain to its leaf domain (marked contiguous) unless an
// allocation domain has already been set, in which case it is left untouched.
void setShardedAllocationDomain(TensorView* tv) {
  if (tv->hasAllocation()) {
    return;
  }
  tv->setAllocationDomain(tv->getLeafDomain(), true);
}

} // namespace

void propagateShardings(Fusion* fusion) {
Expand All @@ -325,15 +331,16 @@ void propagateShardings(Fusion* fusion) {
auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());
TensorView* input_with_mesh = nullptr;
for (auto tv : inputs) {
if (tv->hasDeviceMesh()) {
NVF_CHECK(
tv->hasDeviceMesh(),
"Tensor ",
tv->toString(),
" should be assigned a DeviceMesh");
if (input_with_mesh == nullptr) {
input_with_mesh = tv;
break;
}
}
NVF_ERROR(
input_with_mesh != nullptr,
"At least one input requires a DeviceMesh ",
expr->toString());

std::vector<TensorView*> outputs_without_mesh;
for (auto tv : outputs) {
if (!tv->hasDeviceMesh()) {
Expand All @@ -360,7 +367,7 @@ void insertShardedAxisReordering(Fusion* fusion) {
}
NVF_ERROR(
ir_utils::isTvOp(expr),
"Non-tv op is not supported : ",
"Non-tv op is not supported:",
expr->toString());
NVF_ERROR(
expr->outputs().size() == 1,
Expand All @@ -375,7 +382,8 @@ void insertShardedAxisReordering(Fusion* fusion) {
auto [shard_additions, shard_deletions] = getShardingChanges(expr);
NVF_ERROR(
shard_additions.size() + shard_deletions.size() <= 1,
"Resharding expr can only support one axis")
"Resharding expr can only support one axis:",
expr->toString())

// For gather operations i.e. ID goes from sharded to unsharded
// this will rematerialize a sharded axis.
Expand Down Expand Up @@ -464,6 +472,28 @@ void insertShardedAxisReordering(Fusion* fusion) {
}
}

// Resharding expressions are lowered to collective-communication calls that
// expect contiguous input tensors and produce contiguous output buffers.
// For every resharding expression, this pass checks that the inputs are
// contiguous and sets the allocation domain of its inputs and outputs to
// their leaf domain (when not already set). It should run after all passes
// that add or update resharding expressions.
void setShardedAllocationDomain(Fusion* fusion) {
  for (Expr* expr : fusion->exprs()) {
    if (!isResharding(expr)) {
      continue;
    }
    for (TensorView* tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
      // Some domains (e.g. broadcast/reduction) carry no contiguity value;
      // every domain that does must be contiguous.
      for (const auto& c : tv->getContiguity()) {
        if (c.has_value()) {
          // Use expr->toString() for a readable message, consistent with the
          // other NVF_ERROR/NVF_CHECK sites in this file.
          NVF_CHECK(
              c.value(),
              "Resharding expression input must be contiguous: ",
              expr->toString());
        }
      }
      setShardedAllocationDomain(tv);
    }
    for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
      setShardedAllocationDomain(tv);
    }
  }
}

int64_t requestedNumberOfDevices(Fusion* fusion) {
DeviceIdxType max_index = 0;
for (auto tv : ir_utils::allTvs(fusion)) {
Expand Down
7 changes: 7 additions & 0 deletions csrc/multidevice/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ void insertReshardings(Fusion* fusion);
// to the front so that communication operations are contiguous.
void insertShardedAxisReordering(Fusion* fusion);

// Resharding expressions are mapped to collective libraries which expect
// contiguous tensors and output contiguous buffers. This pass checks that
// inputs are contiguous and sets the allocation domain of inputs and outputs of
// all resharding expressions. This pass should run after all passes that add or
// update resharding expressions.
void setShardedAllocationDomain(Fusion* fusion);

// Returns the index of a sharded axis; returns -1 if none exists.
// TODO: Assumes no merges/splits on sharded axis.
int64_t getShardedAxis(TensorView*);
Expand Down
15 changes: 6 additions & 9 deletions tests/cpp/test_multidevice_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include <kernel_ir.h>
#include <mma_type.h>
#include <ops/all_ops.h>
#include <preseg_passes/allocation_order_inference.h>
#include <preseg_passes/optimization_pass.h>
#include <scheduler/mma_utils.h>
#include <scheduler/utils.h>
#include <tests/cpp/multidevice.h>
Expand All @@ -36,8 +34,7 @@ namespace nvfuser {

class DistributedMatmulTest : public MultiDeviceTest {
protected:
DistributedMatmulTest()
: num_devices_(communicator->size()), optimization_guard_(false) {}
DistributedMatmulTest() : num_devices_(communicator->size()) {}

void SetUp() {
MultiDeviceTest::SetUp();
Expand Down Expand Up @@ -67,11 +64,6 @@ class DistributedMatmulTest : public MultiDeviceTest {
atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kFloat);
return std::make_tuple(a, b, c);
}

private:
preseg_passes::OptimizationPassGuard<preseg_passes::AllocationDomainPass>
optimization_guard_;
DisableOptionsGuard option_guard_;
};

TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {
Expand Down Expand Up @@ -107,6 +99,11 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {
}
b->setDeviceMesh(mesh);

// TODO: If c's allocation domain isn't set, it will fail validation at
// csrc/device_lower/validation.cpp:419, Vectorized dim for consumer has to be
// from a contiguous inner most position.
c->setAllocationDomain(c->getLeafDomain(), true);

auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K);
in0 = in0.view({Mo, Mi, K});
out = out.view({Mo, Mi, N});
Expand Down
9 changes: 5 additions & 4 deletions tests/cpp/test_multidevice_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ TEST_P(MultideviceShardingTest, UnshardedGlobalInput) {
input_size[sharded_dim] = num_devices;
input_size[sharded_output_dim] = num_devices;

TensorView* tv0 = creates_concrete_tensor ? makeConcreteTensor(input_size)
: makeSymbolicTensor(4);
TensorView* tv0 = creates_concrete_tensor
? makeContigConcreteTensor(input_size)
: makeContigTensor(4);
TensorView* tv1 = set(tv0);
TensorView* tv2 = add(tv1, tv1);
TensorView* tv3 = sum(tv2, {sharded_dim});
Expand Down Expand Up @@ -82,8 +83,8 @@ TEST_P(MultideviceShardingTest, ShardGlobalInput) {
unsharded_input_size[sharded_dim] = num_devices;

TensorView* tv0 = creates_concrete_tensor
? makeConcreteTensor(unsharded_input_size)
: makeSymbolicTensor(unsharded_input_size.size());
? makeContigConcreteTensor(unsharded_input_size)
: makeContigTensor(unsharded_input_size.size());
TensorView* tv1 = set(tv0);
TensorView* tv2 = add(tv1, tv1);
fusion->addInput(tv0);
Expand Down
56 changes: 55 additions & 1 deletion tests/cpp/test_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <fusion.h>

#include <gtest/gtest.h>

#include <fusion.h>
#include <multidevice/executor.h>
#include <multidevice/utils.h>
#include <ops/all_ops.h>
Expand Down Expand Up @@ -62,6 +64,58 @@ TEST_F(ShardingTest, PropagateSharding) {
EXPECT_TRUE(getTvsWithDifferentSharding(a, tvs).empty());
}

// Expects tv to have an allocation domain in which every dimension that
// carries a contiguity value is contiguous. Reduction and broadcast
// dimensions are expected to carry no contiguity value at all.
void isContiguous(TensorView* tv) {
  EXPECT_TRUE(tv->hasAllocation());
  const auto contiguity = tv->getContiguity();
  const auto alloc_domain = tv->getAllocationDomain();
  for (auto idx : c10::irange(contiguity.size())) {
    // TODO: This should eventually check that DeviceDim domains also has no
    // value.
    IterDomain* id = alloc_domain[idx];
    if (id->isReduction() || id->isBroadcast()) {
      EXPECT_FALSE(contiguity[idx].has_value());
    } else {
      EXPECT_TRUE(contiguity[idx].value());
    }
  }
}

// Builds a fusion with a sharded intermediate, runs the sharding passes in
// the same order as MultiDeviceExecutor, and verifies that every input and
// output of each resharding expression ends up with a contiguous allocation
// domain. (The original scraped text contained interleaved review-comment UI
// junk inside this function body; it is removed here.)
TEST_F(ShardingTest, ShardedAllocationDomain) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* a = makeContigTensor(3);
  TensorView* b = makeContigTensor(3);
  TensorView* c = add(a, b);
  TensorView* d = sum(c, {1});

  DeviceMesh mesh = DeviceMesh::createForNumDevices(3);
  for (auto tv : {a, b, c, d}) {
    tv->setDeviceMesh(mesh);
  }

  int sharded_dim = 1;
  a->axis(sharded_dim)->parallelize(ParallelType::DIDx);
  c->axis(sharded_dim)->parallelize(ParallelType::DIDx);
  fusion.addInput(a);
  fusion.addInput(b);
  fusion.addOutput(d);

  // Same pass order as MultiDeviceExecutor's constructor.
  propagateShardings(&fusion);
  insertReshardings(&fusion);
  insertShardedAxisReordering(&fusion);
  setShardedAllocationDomain(&fusion);

  for (auto expr : fusion.exprs()) {
    if (isResharding(expr)) {
      for (auto tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
        isContiguous(tv);
      }
      for (auto tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
        isContiguous(tv);
      }
    }
  }
}

TEST_P(ShardingTest, ComputeIndex) {
const bool creates_concrete_tensor = GetParam();
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
Expand Down