From f4fd079634f15caa9e25658214db73df621e77f2 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 28 May 2024 23:05:12 +0000 Subject: [PATCH 1/5] use matmul API --- tests/cpp/test_multidevice_matmul.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 5979e424c2e..3fd10bafb2e 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -61,7 +61,7 @@ class DistributedMatmulTest : public MultiDeviceTest { auto b = matmulAtInput2D( layout, TensorMatmulPos::B, type, M, N, K, 0, local_rank); auto c = - atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kFloat); + atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kHalf); return std::make_tuple(a, b, c); } }; @@ -82,17 +82,19 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K) TensorView* b = makeContigTensor(2, DataType::Half); // (N,K) - TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K) - TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K) - TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K) - TensorView* c = sum(ab, {-1}); // (Mo,Mi,N,r) + // TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K) + // TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K) + // TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K) + // TensorView* c = sum(ab, {-1}); // (Mo,Mi,N,r) + TensorView* c = matmul(a, b); fusion->addInput(a); fusion->addInput(b); fusion->addOutput(c); // Sharding M dimension - auto all_sharded_tvs = {a, a_b, b_b, ab, c}; + // auto all_sharded_tvs = {a, a_b, b_b, ab, c}; + auto all_sharded_tvs = {a, c}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); tv->setDeviceMesh(mesh); From 8bb37e4509bab0c0ef41db309e27e2c95954602d Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 11 Jun 2024 02:34:38 +0000 Subject: [PATCH 2/5] use matmul --- tests/cpp/test_multidevice_matmul.cpp | 36 ++++++++------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 3fd10bafb2e..485b7ad3be9 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -82,18 +82,14 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K) TensorView* b = makeContigTensor(2, DataType::Half); // (N,K) - // TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K) - // TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K) - // TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K) - // TensorView* c = sum(ab, {-1}); // (Mo,Mi,N,r) - TensorView* c = matmul(a, b); + TensorView* b_t = transpose(b, 0, 1); // (K,N) + TensorView* c = matmul(a, b_t); //(Mo,Mi,N,r) fusion->addInput(a); fusion->addInput(b); fusion->addOutput(c); // Sharding M dimension - // auto all_sharded_tvs = {a, a_b, b_b, ab, c}; auto all_sharded_tvs = {a, c}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); @@ -141,14 +137,12 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { int Mo = num_devices_; int Mi = M / Mo; std::vector a_shape = {Mo, Mi, K}; - std::vector b_shape = {N, K}; + std::vector b_shape = {N,K}; TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K) TensorView* b = makeContigTensor(2, DataType::Half); // (N,K) - TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K) - TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K) - TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K) - TensorView* c0 = sum(ab, {-1}); // (Mo,Mi,N,r) + TensorView* b_t = transpose(b, 0, 1); // (K,N) + TensorView* c0 = matmul(a, b_t); //(Mo,Mi,N,r) TensorView* c = set(c0); fusion->addInput(a); @@ -156,7 +150,7 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { fusion->addOutput(c); // Sharding M dimension - auto all_sharded_tvs = {a, a_b, b_b, ab, c0}; + auto all_sharded_tvs = {a, c0}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); tv->setDeviceMesh(mesh); @@ -202,13 +196,9 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { TensorView* a = makeContigTensor(3, DataType::Half); // (Ko,Ki,M) TensorView* b = makeContigTensor(3, DataType::Half); // (Ko,Ki,N) - // Transpose into TN layout, keep Ko (device axis) as the outermost. + // Transpose into TT layout, keep Ko (device axis) as the outermost. TensorView* a_t = transpose(a, 1, 2); // (Ko,M,Ki) - TensorView* b_t = transpose(b, 1, 2); // (Ko,N,Ki) - TensorView* a_b = broadcast(a_t, {false, false, true, false}); // (Ko,M,b,Ki) - TensorView* b_b = broadcast(b_t, {false, true, false, false}); // (Ko,b,N,Ki) - TensorView* ab = mul(a_b, b_b); // (Ko,M,N,Ki) - TensorView* c0 = sum(ab, {-1}); // (Ko,M,N,r) + TensorView* c0 = matmul(a_t, b); // (Ko,M,N,r) TensorView* c = sum(c0, {0}); // (r,M,N) fusion->addInput(a); @@ -216,7 +206,7 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { fusion->addOutput(c); // Parallelize K on all inputs and intermediates. - auto all_sharded_tvs = {a, b, a_t, b_t, a_b, b_b, ab, c0}; + auto all_sharded_tvs = {a, b, a_t, c0}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); tv->setDeviceMesh(mesh); @@ -258,11 +248,7 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { TensorView* a = makeContigTensor(3, DataType::Half); // (Ko,Ki,M) TensorView* b = makeContigTensor(3, DataType::Half); // (Ko,Ki,N) TensorView* a_t = transpose(a, 1, 2); // (Ko, M, Ki) - TensorView* b_t = transpose(b, 1, 2); // (Ko, N, Ki) - TensorView* a_b = broadcast(a_t, {false, false, true, false}); // (Ko,M,b,Ki) - TensorView* b_b = broadcast(b_t, {false, true, false, false}); // (Ko,b,N,Ki) - TensorView* ab = mul(a_b, b_b); // (Ko,M,N,Ki) - TensorView* c0 = sum(ab, {-1}); // (Ko,M,N,r) + TensorView* c0 = matmul(a_t, b); // (Ko,M,N,r) c0 = segment_set(c0); std::vector orig_size = {K, M, N}; std::vector new_size = {K, Mo, Mi, N}; @@ -274,7 +260,7 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { fusion->addOutput(c); // Sharding K dimension of all inputs and intermediates. - auto all_sharded_tvs = {a, b, a_t, b_t, a_b, b_b, ab, c0, c1}; + auto all_sharded_tvs = {a, b, a_t, c0, c1}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); tv->setDeviceMesh(mesh); From c84db619cfe6d2498510b564ea70cf88cf2db7d3 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 12 Jun 2024 07:06:50 +0000 Subject: [PATCH 3/5] K->Ko --- tests/cpp/test_multidevice_matmul.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 485b7ad3be9..31f93763e00 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -137,7 +137,7 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { int Mo = num_devices_; int Mi = M / Mo; std::vector a_shape = {Mo, Mi, K}; - std::vector b_shape = {N,K}; + std::vector b_shape = {N, K}; TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K) TensorView* b = makeContigTensor(2, DataType::Half); // (N,K) @@ -250,8 +250,8 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { TensorView* a_t = transpose(a, 1, 2); // (Ko, M, Ki) TensorView* c0 = matmul(a_t, b); // (Ko,M,N,r) c0 = segment_set(c0); - std::vector orig_size = {K, M, N}; - std::vector new_size = {K, Mo, Mi, N}; + std::vector orig_size = {Ko, M, N}; + std::vector new_size = {Ko, Mo, Mi, N}; TensorView* c1 = reshape(c0, orig_size, new_size); TensorView* c = sum(c1, {0}); From eb5dd4574891db007123ae4fa39288a54c837bd3 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Sat, 15 Jun 2024 00:13:22 +0000 Subject: [PATCH 4/5] add back the mul sum test --- tests/cpp/test_multidevice_matmul.cpp | 80 +++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 31f93763e00..885dd572286 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -53,20 +53,74 @@ class DistributedMatmulTest : public MultiDeviceTest { MmaLayout layout, int M, int N, - int K) { + int K, + c10::ScalarType dtype) { int local_rank = communicator->local_rank(); c10::ScalarType type = c10::ScalarType::Half; auto a = matmulAtInput2D( layout, TensorMatmulPos::A, type, M, N, K, 0, local_rank); auto b = matmulAtInput2D( layout, TensorMatmulPos::B, type, M, N, K, 0, local_rank); - auto c = - atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kHalf); + auto c = atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(dtype); return std::make_tuple(a, b, c); } }; -TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { +TEST_F(DistributedMatmulTest, MulSum_LayoutTN_NoComms) { + // MmaLayout::TN A(T), B(N), C(T) + // A and C are sharded on dimension M + // Tests local matmul with no communication + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto mesh = DeviceMesh::createForNumDevices(communicator->size()); + int M = 256, N = 64, K = 64; + int Mo = num_devices_; + int Mi = M / Mo; + std::vector a_shape = {Mo, Mi, K}; + std::vector b_shape = {N, K}; + + TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K) + TensorView* b = makeContigTensor(2, DataType::Half); // (N,K) + TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K) + TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K) + TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K) + TensorView* c = sum(ab, {-1}); // (Mo,Mi,N,r) + + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + + // Sharding M dimension + auto all_sharded_tvs = {a, a_b, b_b, ab, c}; + for (auto tv : all_sharded_tvs) { + tv->axis(0)->parallelize(ParallelType::DIDx); + tv->setDeviceMesh(mesh); + } + b->setDeviceMesh(mesh); + // TODO: If c's allocation domain isn't set, it will fail validation at + // csrc/device_lower/validation.cpp:419, Vectorized dim for consumer has to be + // from a contiguous inner most position. + c->setAllocationDomain(c->getLeafDomain(), true); + auto [in0, in1, out] = getInputsAndReferenceOutputs( + MmaLayout::TN, M, N, K, /*dtype=*/at::kFloat); + in0 = in0.view({Mo, Mi, K}); + out = out.view({Mo, Mi, N}); + std::vector inputs = { + shardTensor(in0, a, communicator->deviceId()), in1}; + auto expected_output = shardTensor(out, c, communicator->deviceId()); + MultiDeviceExecutor runtime( + std::move(fusion), *communicator, executor_params_); + auto outputs = runtime.runWithInput(inputs); + testValidate( + runtime.completeFusion(), + outputs, + inputs, + {expected_output}, + __LINE__, + __FILE__); +} + +TEST_F(DistributedMatmulTest, Matmul_LayoutTN_NoComms) { // MmaLayout::TN A(T), B(N), C(T) // A and C are sharded on dimension M // Tests local matmul with no communication @@ -102,7 +156,8 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { // from a contiguous inner most position. c->setAllocationDomain(c->getLoopDomain(), true); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Mo, Mi, K}); out = out.view({Mo, Mi, N}); std::vector inputs = { @@ -125,7 +180,7 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { EXPECT_EQ(fecs.size(), 1); } -TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { +TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { // MmaLayout::TN matmul A(T), B(N), C(T) // A is sharded on dimension M // Tests local matmul + allgather @@ -158,7 +213,8 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { b->setDeviceMesh(mesh); c->setDeviceMesh(mesh); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Mo, Mi, K}); out = out.view({Mo, Mi, N}); @@ -181,7 +237,7 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { EXPECT_EQ(fecs.size(), 1); } -TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { +TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { // MmaLayout::NT matmul A(N), B(T), C(T) // Sharding: A, B are sharded along K. C is replicated. // Tests local matmul + allreduce @@ -213,7 +269,8 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { } c->setDeviceMesh(mesh); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Ko, Ki, M}); in1 = in1.view({Ko, Ki, N}); std::vector inputs = { @@ -231,7 +288,7 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { EXPECT_EQ(fecs.size(), 1); } -TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { +TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) { // MmaLayout::NT matmul A(N), B(T), C(T) // A, B are sharded on K. C is sharded on M // Tests local matmul + reduce scatter @@ -269,7 +326,8 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { c->setDeviceMesh(mesh); c->axis(1)->parallelize(ParallelType::DIDx); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Ko, Ki, M}); in1 = in1.view({Ko, Ki, N}); out = out.view({Mo, Mi, N}); From 49ebf3badabfdae199ad31077e6456ddbbf6d36b Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Sat, 15 Jun 2024 01:00:08 +0000 Subject: [PATCH 5/5] check scheduler heuristic --- tests/cpp/test_multidevice_matmul.cpp | 55 ++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 885dd572286..61e45fe4684 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -100,7 +100,7 @@ TEST_F(DistributedMatmulTest, MulSum_LayoutTN_NoComms) { // TODO: If c's allocation domain isn't set, it will fail validation at // csrc/device_lower/validation.cpp:419, Vectorized dim for consumer has to be // from a contiguous inner most position. - c->setAllocationDomain(c->getLeafDomain(), true); + c->setAllocationDomain(c->getLoopDomain(), true); auto [in0, in1, out] = getInputsAndReferenceOutputs( MmaLayout::TN, M, N, K, /*dtype=*/at::kFloat); in0 = in0.view({Mo, Mi, K}); @@ -118,6 +118,19 @@ TEST_F(DistributedMatmulTest, MulSum_LayoutTN_NoComms) { {expected_output}, __LINE__, __FILE__); + + std::vector fecs = runtime.getFusionExecutorCaches(); + EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_FALSE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .front() + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::Matmul); } TEST_F(DistributedMatmulTest, Matmul_LayoutTN_NoComms) { @@ -178,6 +191,16 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_NoComms) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { @@ -235,6 +258,16 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { @@ -286,6 +319,16 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) { @@ -350,5 +393,15 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } } // namespace nvfuser