Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 131 additions & 32 deletions tests/cpp/test_multidevice_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,26 @@ class DistributedMatmulTest : public MultiDeviceTest {
MmaLayout layout,
int M,
int N,
int K) {
int K,
c10::ScalarType dtype) {
int local_rank = communicator->local_rank();
c10::ScalarType type = c10::ScalarType::Half;
auto a = matmulAtInput2D(
layout, TensorMatmulPos::A, type, M, N, K, 0, local_rank);
auto b = matmulAtInput2D(
layout, TensorMatmulPos::B, type, M, N, K, 0, local_rank);
auto c =
atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kFloat);
auto c = atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(dtype);
return std::make_tuple(a, b, c);
}
};

TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {
TEST_F(DistributedMatmulTest, MulSum_LayoutTN_NoComms) {
// MmaLayout::TN A(T), B(N), C(T)
// A and C are sharded on dimension M
// Tests local matmul with no communication
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
auto mesh = DeviceMesh::createForNumDevices(communicator->size());

int M = 256, N = 64, K = 64;
int Mo = num_devices_;
int Mi = M / Mo;
Expand All @@ -98,13 +97,80 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {
tv->setDeviceMesh(mesh);
}
b->setDeviceMesh(mesh);
// TODO: If c's allocation domain isn't set, validation fails at
// csrc/device_lower/validation.cpp:419 with "Vectorized dim for consumer has
// to be from a contiguous inner most position."
c->setAllocationDomain(c->getLoopDomain(), true);
auto [in0, in1, out] = getInputsAndReferenceOutputs(
MmaLayout::TN, M, N, K, /*dtype=*/at::kFloat);
in0 = in0.view({Mo, Mi, K});
out = out.view({Mo, Mi, N});
std::vector<c10::IValue> inputs = {
shardTensor(in0, a, communicator->deviceId()), in1};
auto expected_output = shardTensor(out, c, communicator->deviceId());
MultiDeviceExecutor runtime(
std::move(fusion), *communicator, executor_params_);
auto outputs = runtime.runWithInput(inputs);
testValidate(
runtime.completeFusion(),
outputs,
inputs,
{expected_output},
__LINE__,
__FILE__);

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);

const FusionKernelRuntime* kernel_runtime =
fecs.front()->getMostRecentKernelRuntime();
EXPECT_FALSE(kernel_runtime->isSegmented());

ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics()
->heuristicsList()
.front()
->heuristic();
EXPECT_EQ(heuristic, ScheduleHeuristic::Matmul);
}

TEST_F(DistributedMatmulTest, Matmul_LayoutTN_NoComms) {
// MmaLayout::TN A(T), B(N), C(T)
// A and C are sharded on dimension M
// Tests local matmul with no communication
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
auto mesh = DeviceMesh::createForNumDevices(communicator->size());

int M = 256, N = 64, K = 64;
int Mo = num_devices_;
int Mi = M / Mo;
std::vector<int> a_shape = {Mo, Mi, K};
std::vector<int> b_shape = {N, K};

TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K)
TensorView* b = makeContigTensor(2, DataType::Half); // (N,K)
TensorView* b_t = transpose(b, 0, 1); // (K,N)
TensorView* c = matmul(a, b_t); //(Mo,Mi,N,r)

fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);

// Sharding M dimension
auto all_sharded_tvs = {a, c};
for (auto tv : all_sharded_tvs) {
tv->axis(0)->parallelize(ParallelType::DIDx);
tv->setDeviceMesh(mesh);
}
b->setDeviceMesh(mesh);

// TODO: If c's allocation domain isn't set, validation fails at
// csrc/device_lower/validation.cpp:419 with "Vectorized dim for consumer has
// to be from a contiguous inner most position."
c->setAllocationDomain(c->getLoopDomain(), true);

auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K);
auto [in0, in1, out] =
getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K, /*dtype=*/at::kHalf);
in0 = in0.view({Mo, Mi, K});
out = out.view({Mo, Mi, N});
std::vector<c10::IValue> inputs = {
Expand All @@ -125,9 +191,19 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);

const FusionKernelRuntime* kernel_runtime =
fecs.front()->getMostRecentKernelRuntime();
EXPECT_TRUE(kernel_runtime->isSegmented());

ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics()
->heuristicsList()
.at(1)
->heuristic();
EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
}

TEST_F(DistributedMatmulTest, LayoutTN_Allgather) {
TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) {
// MmaLayout::TN matmul A(T), B(N), C(T)
// A is sharded on dimension M
// Tests local matmul + allgather
Expand All @@ -143,26 +219,25 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) {

TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K)
TensorView* b = makeContigTensor(2, DataType::Half); // (N,K)
TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K)
TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K)
TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K)
TensorView* c0 = sum(ab, {-1}); // (Mo,Mi,N,r)
TensorView* b_t = transpose(b, 0, 1); // (K,N)
TensorView* c0 = matmul(a, b_t); //(Mo,Mi,N,r)
TensorView* c = set(c0);

fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);

// Sharding M dimension
auto all_sharded_tvs = {a, a_b, b_b, ab, c0};
auto all_sharded_tvs = {a, c0};
for (auto tv : all_sharded_tvs) {
tv->axis(0)->parallelize(ParallelType::DIDx);
tv->setDeviceMesh(mesh);
}
b->setDeviceMesh(mesh);
c->setDeviceMesh(mesh);

auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K);
auto [in0, in1, out] =
getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K, /*dtype=*/at::kHalf);
in0 = in0.view({Mo, Mi, K});
out = out.view({Mo, Mi, N});

Expand All @@ -183,9 +258,19 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) {

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);

const FusionKernelRuntime* kernel_runtime =
fecs.front()->getMostRecentKernelRuntime();
EXPECT_TRUE(kernel_runtime->isSegmented());

ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics()
->heuristicsList()
.at(1)
->heuristic();
EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
}

TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) {
TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) {
// MmaLayout::NT matmul A(N), B(T), C(T)
// Sharding: A, B are sharded along K. C is replicated.
// Tests local matmul + allreduce
Expand All @@ -200,28 +285,25 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) {

TensorView* a = makeContigTensor(3, DataType::Half); // (Ko,Ki,M)
TensorView* b = makeContigTensor(3, DataType::Half); // (Ko,Ki,N)
// Transpose into TN layout, keep Ko (device axis) as the outermost.
// Transpose into TT layout, keep Ko (device axis) as the outermost.
TensorView* a_t = transpose(a, 1, 2); // (Ko,M,Ki)
TensorView* b_t = transpose(b, 1, 2); // (Ko,N,Ki)
TensorView* a_b = broadcast(a_t, {false, false, true, false}); // (Ko,M,b,Ki)
TensorView* b_b = broadcast(b_t, {false, true, false, false}); // (Ko,b,N,Ki)
TensorView* ab = mul(a_b, b_b); // (Ko,M,N,Ki)
TensorView* c0 = sum(ab, {-1}); // (Ko,M,N,r)
TensorView* c0 = matmul(a_t, b); // (Ko,M,N,r)
TensorView* c = sum(c0, {0}); // (r,M,N)

fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);

// Parallelize K on all inputs and intermediates.
auto all_sharded_tvs = {a, b, a_t, b_t, a_b, b_b, ab, c0};
auto all_sharded_tvs = {a, b, a_t, c0};
for (auto tv : all_sharded_tvs) {
tv->axis(0)->parallelize(ParallelType::DIDx);
tv->setDeviceMesh(mesh);
}
c->setDeviceMesh(mesh);

auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K);
auto [in0, in1, out] =
getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K, /*dtype=*/at::kHalf);
in0 = in0.view({Ko, Ki, M});
in1 = in1.view({Ko, Ki, N});
std::vector<c10::IValue> inputs = {
Expand All @@ -237,9 +319,19 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) {

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);

const FusionKernelRuntime* kernel_runtime =
fecs.front()->getMostRecentKernelRuntime();
EXPECT_TRUE(kernel_runtime->isSegmented());

ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics()
->heuristicsList()
.at(1)
->heuristic();
EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
}

TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) {
TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) {
// MmaLayout::NT matmul A(N), B(T), C(T)
// A, B are sharded on K. C is sharded on M
// Tests local matmul + reduce scatter
Expand All @@ -256,14 +348,10 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) {
TensorView* a = makeContigTensor(3, DataType::Half); // (Ko,Ki,M)
TensorView* b = makeContigTensor(3, DataType::Half); // (Ko,Ki,N)
TensorView* a_t = transpose(a, 1, 2); // (Ko, M, Ki)
TensorView* b_t = transpose(b, 1, 2); // (Ko, N, Ki)
TensorView* a_b = broadcast(a_t, {false, false, true, false}); // (Ko,M,b,Ki)
TensorView* b_b = broadcast(b_t, {false, true, false, false}); // (Ko,b,N,Ki)
TensorView* ab = mul(a_b, b_b); // (Ko,M,N,Ki)
TensorView* c0 = sum(ab, {-1}); // (Ko,M,N,r)
TensorView* c0 = matmul(a_t, b); // (Ko,M,N,r)
c0 = segment_set(c0);
std::vector<int64_t> orig_size = {K, M, N};
std::vector<int64_t> new_size = {K, Mo, Mi, N};
std::vector<int64_t> orig_size = {Ko, M, N};
std::vector<int64_t> new_size = {Ko, Mo, Mi, N};
TensorView* c1 = reshape(c0, orig_size, new_size);
TensorView* c = sum(c1, {0});

Expand All @@ -272,7 +360,7 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) {
fusion->addOutput(c);

// Sharding K dimension of all inputs and intermediates.
auto all_sharded_tvs = {a, b, a_t, b_t, a_b, b_b, ab, c0, c1};
auto all_sharded_tvs = {a, b, a_t, c0, c1};
for (auto tv : all_sharded_tvs) {
tv->axis(0)->parallelize(ParallelType::DIDx);
tv->setDeviceMesh(mesh);
Expand All @@ -281,7 +369,8 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) {
c->setDeviceMesh(mesh);
c->axis(1)->parallelize(ParallelType::DIDx);

auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K);
auto [in0, in1, out] =
getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K, /*dtype=*/at::kHalf);
in0 = in0.view({Ko, Ki, M});
in1 = in1.view({Ko, Ki, N});
out = out.view({Mo, Mi, N});
Expand All @@ -304,5 +393,15 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) {

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);

const FusionKernelRuntime* kernel_runtime =
fecs.front()->getMostRecentKernelRuntime();
EXPECT_TRUE(kernel_runtime->isSegmented());

ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics()
->heuristicsList()
.at(1)
->heuristic();
EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
}
} // namespace nvfuser