diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 5979e424c2e..61e45fe4684 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -53,27 +53,26 @@ class DistributedMatmulTest : public MultiDeviceTest { MmaLayout layout, int M, int N, - int K) { + int K, + c10::ScalarType dtype) { int local_rank = communicator->local_rank(); c10::ScalarType type = c10::ScalarType::Half; auto a = matmulAtInput2D( layout, TensorMatmulPos::A, type, M, N, K, 0, local_rank); auto b = matmulAtInput2D( layout, TensorMatmulPos::B, type, M, N, K, 0, local_rank); - auto c = - atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(at::kFloat); + auto c = atMatmul(a.to(at::kDouble), b.to(at::kDouble), layout).to(dtype); return std::make_tuple(a, b, c); } }; -TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { +TEST_F(DistributedMatmulTest, MulSum_LayoutTN_NoComms) { // MmaLayout::TN A(T), B(N), C(T) // A and C are sharded on dimension M // Tests local matmul with no communication std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto mesh = DeviceMesh::createForNumDevices(communicator->size()); - int M = 256, N = 64, K = 64; int Mo = num_devices_; int Mi = M / Mo; @@ -98,13 +97,80 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { tv->setDeviceMesh(mesh); } b->setDeviceMesh(mesh); + // TODO: If c's allocation domain isn't set, it will fail validation at + // csrc/device_lower/validation.cpp:419, Vectorized dim for consumer has to be + // from a contiguous inner most position. + c->setAllocationDomain(c->getLoopDomain(), true); + auto [in0, in1, out] = getInputsAndReferenceOutputs( + MmaLayout::TN, M, N, K, /*dtype=*/at::kFloat); + in0 = in0.view({Mo, Mi, K}); + out = out.view({Mo, Mi, N}); + std::vector inputs = { + shardTensor(in0, a, communicator->deviceId()), in1}; + auto expected_output = shardTensor(out, c, communicator->deviceId()); + MultiDeviceExecutor runtime( + std::move(fusion), *communicator, executor_params_); + auto outputs = runtime.runWithInput(inputs); + testValidate( + runtime.completeFusion(), + outputs, + inputs, + {expected_output}, + __LINE__, + __FILE__); + + std::vector fecs = runtime.getFusionExecutorCaches(); + EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_FALSE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .front() + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::Matmul); +} + +TEST_F(DistributedMatmulTest, Matmul_LayoutTN_NoComms) { + // MmaLayout::TN A(T), B(N), C(T) + // A and C are sharded on dimension M + // Tests local matmul with no communication + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto mesh = DeviceMesh::createForNumDevices(communicator->size()); + + int M = 256, N = 64, K = 64; + int Mo = num_devices_; + int Mi = M / Mo; + std::vector a_shape = {Mo, Mi, K}; + std::vector b_shape = {N, K}; + + TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K) + TensorView* b = makeContigTensor(2, DataType::Half); // (N,K) + TensorView* b_t = transpose(b, 0, 1); // (K,N) + TensorView* c = matmul(a, b_t); //(Mo,Mi,N,r) + + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + + // Sharding M dimension + auto all_sharded_tvs = {a, c}; + for (auto tv : all_sharded_tvs) { + tv->axis(0)->parallelize(ParallelType::DIDx); + tv->setDeviceMesh(mesh); + } + b->setDeviceMesh(mesh); // TODO: If c's allocation domain isn't set, it will fail validation at // csrc/device_lower/validation.cpp:419, Vectorized dim for consumer has to be // from a contiguous inner most position. c->setAllocationDomain(c->getLoopDomain(), true); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Mo, Mi, K}); out = out.view({Mo, Mi, N}); std::vector inputs = { @@ -125,9 +191,19 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } -TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { +TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { // MmaLayout::TN matmul A(T), B(N), C(T) // A is sharded on dimension M // Tests local matmul + allgather @@ -143,10 +219,8 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { TensorView* a = makeContigTensor(3, DataType::Half); // (Mo,Mi,K) TensorView* b = makeContigTensor(2, DataType::Half); // (N,K) - TensorView* a_b = broadcast(a, {false, false, true, false}); // (Mo,Mi,b,K) - TensorView* b_b = broadcast(b, {true, true, false, false}); // (b,b,N,K) - TensorView* ab = mul(a_b, b_b); // (Mo,Mi,N,K) - TensorView* c0 = sum(ab, {-1}); // (Mo,Mi,N,r) + TensorView* b_t = transpose(b, 0, 1); // (K,N) + TensorView* c0 = matmul(a, b_t); //(Mo,Mi,N,r) TensorView* c = set(c0); fusion->addInput(a); @@ -154,7 +228,7 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { fusion->addOutput(c); // Sharding M dimension - auto all_sharded_tvs = {a, a_b, b_b, ab, c0}; + auto all_sharded_tvs = {a, c0}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); tv->setDeviceMesh(mesh); @@ -162,7 +236,8 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { b->setDeviceMesh(mesh); c->setDeviceMesh(mesh); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::TN, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Mo, Mi, K}); out = out.view({Mo, Mi, N}); @@ -183,9 +258,19 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } -TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { +TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { // MmaLayout::NT matmul A(N), B(T), C(T) // Sharding: A, B are sharded along K. C is replicated. // Tests local matmul + allreduce @@ -200,13 +285,9 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { TensorView* a = makeContigTensor(3, DataType::Half); // (Ko,Ki,M) TensorView* b = makeContigTensor(3, DataType::Half); // (Ko,Ki,N) - // Transpose into TN layout, keep Ko (device axis) as the outermost. + // Transpose into TT layout, keep Ko (device axis) as the outermost. TensorView* a_t = transpose(a, 1, 2); // (Ko,M,Ki) - TensorView* b_t = transpose(b, 1, 2); // (Ko,N,Ki) - TensorView* a_b = broadcast(a_t, {false, false, true, false}); // (Ko,M,b,Ki) - TensorView* b_b = broadcast(b_t, {false, true, false, false}); // (Ko,b,N,Ki) - TensorView* ab = mul(a_b, b_b); // (Ko,M,N,Ki) - TensorView* c0 = sum(ab, {-1}); // (Ko,M,N,r) + TensorView* c0 = matmul(a_t, b); // (Ko,M,N,r) TensorView* c = sum(c0, {0}); // (r,M,N) fusion->addInput(a); @@ -214,14 +295,15 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { fusion->addOutput(c); // Parallelize K on all inputs and intermediates. - auto all_sharded_tvs = {a, b, a_t, b_t, a_b, b_b, ab, c0}; + auto all_sharded_tvs = {a, b, a_t, c0}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); tv->setDeviceMesh(mesh); } c->setDeviceMesh(mesh); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Ko, Ki, M}); in1 = in1.view({Ko, Ki, N}); std::vector inputs = { @@ -237,9 +319,19 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } -TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { +TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) { // MmaLayout::NT matmul A(N), B(T), C(T) // A, B are sharded on K. C is sharded on M // Tests local matmul + reduce scatter @@ -256,14 +348,10 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { TensorView* a = makeContigTensor(3, DataType::Half); // (Ko,Ki,M) TensorView* b = makeContigTensor(3, DataType::Half); // (Ko,Ki,N) TensorView* a_t = transpose(a, 1, 2); // (Ko, M, Ki) - TensorView* b_t = transpose(b, 1, 2); // (Ko, N, Ki) - TensorView* a_b = broadcast(a_t, {false, false, true, false}); // (Ko,M,b,Ki) - TensorView* b_b = broadcast(b_t, {false, true, false, false}); // (Ko,b,N,Ki) - TensorView* ab = mul(a_b, b_b); // (Ko,M,N,Ki) - TensorView* c0 = sum(ab, {-1}); // (Ko,M,N,r) + TensorView* c0 = matmul(a_t, b); // (Ko,M,N,r) c0 = segment_set(c0); - std::vector orig_size = {K, M, N}; - std::vector new_size = {K, Mo, Mi, N}; + std::vector orig_size = {Ko, M, N}; + std::vector new_size = {Ko, Mo, Mi, N}; TensorView* c1 = reshape(c0, orig_size, new_size); TensorView* c = sum(c1, {0}); @@ -272,7 +360,7 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { fusion->addOutput(c); // Sharding K dimension of all inputs and intermediates. - auto all_sharded_tvs = {a, b, a_t, b_t, a_b, b_b, ab, c0, c1}; + auto all_sharded_tvs = {a, b, a_t, c0, c1}; for (auto tv : all_sharded_tvs) { tv->axis(0)->parallelize(ParallelType::DIDx); tv->setDeviceMesh(mesh); @@ -281,7 +369,8 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { c->setDeviceMesh(mesh); c->axis(1)->parallelize(ParallelType::DIDx); - auto [in0, in1, out] = getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K); + auto [in0, in1, out] = + getInputsAndReferenceOutputs(MmaLayout::NT, M, N, K, /*dtype=*/at::kHalf); in0 = in0.view({Ko, Ki, M}); in1 = in1.view({Ko, Ki, N}); out = out.view({Mo, Mi, N}); @@ -304,5 +393,15 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) { std::vector fecs = runtime.getFusionExecutorCaches(); EXPECT_EQ(fecs.size(), 1); + + const FusionKernelRuntime* kernel_runtime = + fecs.front()->getMostRecentKernelRuntime(); + EXPECT_TRUE(kernel_runtime->isSegmented()); + + ScheduleHeuristic heuristic = kernel_runtime->schedulerHeuristics() + ->heuristicsList() + .at(1) + ->heuristic(); + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); } } // namespace nvfuser