diff --git a/csrc/options.cpp b/csrc/options.cpp index ee92208cc5d..2f91d9f5c7a 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -153,6 +153,7 @@ template <> std::unordered_map> Options< EnableOption>::getOptionsFromEnv() { const std::unordered_map available_options = { + {"fuse_matmul", EnableOption::FuseMatmul}, {"id_model", EnableOption::IdModel}, {"kernel_db", EnableOption::KernelDb}, {"kernel_profile", EnableOption::KernelProfile}, diff --git a/csrc/options.h b/csrc/options.h index b3455391f8a..8fbd9685930 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -90,6 +90,7 @@ enum class DebugDumpOption { //! These can be set through the `NVFUSER_ENABLE` environment variable //! enum class EnableOption { + FuseMatmul, //! Enable automatic fusion of matmul and linear ops IdModel, //! Enable IdModel KernelDb, //! Enable Kernel Database KernelProfile, //! Enable intra-kernel performance profiling diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index b25a290ce99..777452079cb 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -16,13 +16,18 @@ namespace nvfuser { // Check if the fusion has a single MatmulOp/LinearOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); - if (exprs.size() == 1 && - (exprs.front()->isA() || exprs.front()->isA())) { - return true; + if (!isOptionDisabled(DisableOption::MatmulExprEval)) { + if (exprs.size() == 1 && (exprs.front()->isOneOf())) { + return true; + } + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Fusion must contain a single expression of type MatmulOp or LinearOp"); + } else { + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Matmul ATen evaluation was disabled by NVFUSER_DISABLE=matmul_expr_eval"); } - scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), - "Fusion must contain a single expression of type MatmulOp or LinearOp"); return 
false; } diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 3cd23a1f789..4f116b32db5 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -205,9 +205,9 @@ std::string isMatmulFusionDefinitionSupported( constexpr size_t minimal_number_of_inputs = 2; - // Quick checks - MmaOp + // Quick checks { - // Check if MmaOp represents gemm (requires M/N/K == 1, B == 0) + // Check if matmul pattern represents gemm (requires M/N/K == 1, B == 0) // or bgemm (requires M/N/K/B == 1) std::array num_axes{}; for (const auto& [g, dom] : id_roles) { @@ -218,7 +218,7 @@ std::string isMatmulFusionDefinitionSupported( num_axes[(size_t)MatmulDomain::N] != expected_axes_numbers || num_axes[(size_t)MatmulDomain::K] != expected_axes_numbers || num_axes[(size_t)MatmulDomain::Batch] > expected_axes_numbers) { - return "MmaOp has unsupported number of one of M/N/K/Batch axes"; + return "Matmul pattern has unsupported number of one of M/N/K/Batch axes"; } if (!mma_output->hasReduction()) { @@ -236,31 +236,47 @@ std::string isMatmulFusionDefinitionSupported( // Fusion topology check { - auto entry = roles_map.find(MatmulRole::INPUT_A); + // We will check that all operands have same dimension + int64_t operand_dim = -1; + + // Track TensorViews with assigned roles so we can check that all inputs and + // outputs have recognized roles std::set tvs_with_roles; - if (entry != roles_map.end()) { - if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { - tvs_with_roles.insert(entry->second.begin(), entry->second.end()); + for (MatmulRole role : {MatmulRole::INPUT_A, MatmulRole::INPUT_B}) { + auto entry = roles_map.find(role); + if (entry != roles_map.end()) { + if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { + tvs_with_roles.insert(entry->second.begin(), entry->second.end()); + for (TensorView* tv : entry->second) { + const std::vector& leaf = tv->getLeafDomain(); + int64_t ndims = (int64_t)std::count_if( + 
leaf.begin(), leaf.end(), [](IterDomain* id) { + return !id->isReduction() && !id->isDeviceDim(); + }); + if (operand_dim == -1) { + operand_dim = ndims; + } else if (ndims != operand_dim) { + // We cannot always handle differently sized inputs, such as those + // we encounter when translating MatmulOp and LinearOp. This is + // because in those cases one of the operands will have new + // Broadcast dimensions where the other operand has Iteration + // batch dimensions, meaning these new dims are actually M or N + // dimensions. Multiple M and N dimension support is planned but + // for now we must reject these patterns before attempting to + // translate them. + return "All operands must have the same no-devices dimension."; + } + } + } else { + return "There is more than a single fusion input that can be MMA operand "; + } } else { - return "There is more than a single fusion input that can be MMA first input"; + return "No candidate in fusion inputs for MMA operand"; } - } else { - return "No candidate in fusion inputs for MMA first input"; } - entry = roles_map.find(MatmulRole::INPUT_B); - if (entry != roles_map.end()) { - if (MATMUL_CORE_ROLES_EXPECTED_COUNT == entry->second.size()) { - tvs_with_roles.insert(entry->second.begin(), entry->second.end()); - } else { - return "There is more than a single fusion input that can be MMA second input"; - } - } else { - return "No candidate in fusion inputs for MMA second input"; - } - - entry = roles_map.find(MatmulRole::OUTPUT_D); + auto entry = roles_map.find(MatmulRole::OUTPUT_D); if (entry != roles_map.end()) { tvs_with_roles.insert(entry->second.begin(), entry->second.end()); } else { @@ -458,8 +474,11 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { // The plan: // 0. Check if the current CUDA device is supported // 1. Check if there is exactly one matmul pattern defined in the fusion. - // 2. Check if the input layout for the matmul pattern can be determined - // 3. 
Check if fusion represents expressions that are recognized by matmul + // 2. Check if fusion of MatmulOp and LinearOp is enabled, if applicable + // 3. Check if inputs to the matmul pattern match any of + // supported inputs layout + // 4. Check if fusion represents expressions that are recognized by matmul + // 5. Check if the input layout for the matmul pattern can be determined // scheduler. // #0 @@ -475,13 +494,39 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { } // #1 - // Initializing the machinery to check if there's a Mul-Sum pair - // can be replaced by a Mma Op. + // Find matmul patterns std::vector patterns = mma_utils::findMatmulPatterns(fusion); if (patterns.empty()) { return "No matmul patterns were found"; } + + // #2 + { + for (const mma_utils::MatmulPattern& pattern : patterns) { + Expr* op = pattern.output->definition(); + if (op->isA() || op->isA()) { + if (!isOptionEnabled(EnableOption::FuseMatmul)) { + // Check for MatmulOp or LinearOp. If found, then only fuse if option + // is specified + return "MatmulOp and LinearOp fusion is disabled by default. " + "Enable it using NVFUSER_ENABLE=fuse_matmul"; + } + // Refuse patterns containing 1D inputs since these are mat-vec as + // opposed to mat-mat products. + if (pattern.A->nDims() < 2 || pattern.B->nDims() < 2) { + return "Cannot fuse matrix-vector products"; + } + for (TensorView* operand : {pattern.A, pattern.B}) { + if (operand->dtype() != DataType::Half && + operand->dtype() != DataType::BFloat16) { + return "Unsupported operand type. 
Operands must be fp16 or bf16"; + } + } + } + } + } + if (patterns.size() > 1) { return "Only a single matmul pattern can currently be fused"; } @@ -498,13 +543,6 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { mma_utils::RolesMap roles_map = roles_map_opt.getData(); // #4 - const auto input_layout_opt = - mma_utils::getProblemLayout(id_model, id_roles, roles_map); - if (!input_layout_opt.isValid()) { - return input_layout_opt.getErrorMsg(); - } - - // #5 { auto support_status = isMatmulFusionDefinitionSupported( fusion, patterns.front(), roles_map, id_roles); @@ -513,6 +551,13 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { } } + // #5 + const auto input_layout_opt = + mma_utils::getProblemLayout(id_model, id_roles, roles_map); + if (!input_layout_opt.isValid()) { + return input_layout_opt.getErrorMsg(); + } + return ""; } diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 44c6e465f77..d3555adb08b 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -1375,6 +1376,24 @@ class MatmulPatternMatcher : IterVisitor { // there is a transpose operation in the epilogue, then this assumption will // be violated. In such cases we should actually swap and transpose A and B. + // Match all LinearOps and MatmulOps as MatmulPatterns. This includes ops + // whose inputs are not 2D, i.e. matrix-vector products. The matmul scheduler + // will decide whether or not it can fuse a given pattern based on the + // dimensionality of its inputs. 
+ void handle(LinearOp* lop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = lop->inA()->as(); + pattern.B = lop->inB()->as(); + pattern.output = lop->out()->as(); + } + + void handle(MatmulOp* mop) override { + MatmulPattern& pattern = patterns_.emplace_back(); + pattern.A = mop->inA()->as(); + pattern.B = mop->inB()->as(); + pattern.output = mop->out()->as(); + } + // Handle the case when no translation is needed. void handle(MmaOp* mop) override { MatmulPattern& pattern = patterns_.emplace_back(); @@ -1508,12 +1527,128 @@ MmaOp* MatmulPattern::translateToMmaOp() { // This replaces the mul and sum by overwriting output->definition() return IrBuilder::create(output, A, B, init); } - NVF_ERROR( - false, - "Could not translate matmul pattern with output ", - output->toString(), - " to MmaOp"); + + // This will hold the translated output from MatmulOp or LinearOp + TensorView* fms = nullptr; + MmaOp* mma_op = nullptr; + if (auto lop = dynamic_cast(output->definition())) { + // Linear takes inputs input, weight(, bias) + // - input can be any dimension > 0. We assert that it must be at least 2 + // and refuse to translate if dimension is 1. + // - weight can be one or two dimensional. We refuse to translate if + // dimension is 1. + // - bias, if present, can be zero or one dimensional. Bias can only be + // present if weight is 2D + // + // We translate by broadcasting input, weight, and bias such that the + // contracted dimension K is in the last position (this is true of the + // rfactor domains in input and weight already). Then we form an MmaOp and + // optionally add the bias tensor followed by a cast back to the input + // dtype. 
+ NVF_ERROR( + A->nDims() > 1 && B->nDims() > 1, + "Cannot translate LinearOp with 1D input"); + std::vector bcast_dim((size_t)A->nDims() + 1, false); + bcast_dim[bcast_dim.size() - 2] = true; // N + A = broadcast(A, bcast_dim); + + bcast_dim[bcast_dim.size() - 2] = false; // reset N + std::fill(bcast_dim.begin(), bcast_dim.end() - 2, true); + B = broadcast(B, bcast_dim); + + fms = fusedMultiplySum(A, B, {-1}); + mma_op = fms->definition()->as(); + + auto* bias = dynamic_cast(lop->bias()); + if (bias != nullptr) { + fms = add(fms, bias); + } + } else if (output->definition()->isA()) { + // MatmulOp takes inputs whose sizes are [..., M, K] and [..., K, N], so we + // must transpose B then broadcast both operands before creating the final + // op. + // + // Also note that the output of MatmulOp is a tensor of shape [..., M, N] + // whose dtype matches that of the inputs. We will most commonly then also + // need to cast the output of the MmaOp to produce the output TensorView. + NVF_ERROR( + A->nDims() > 1 && B->nDims() > 1, + "Cannot translate MatmulOp with 1D input"); + TensorView* Btrans = transpose(B, -2, -1); + A = unsqueeze(A, -2); + B = unsqueeze(Btrans, -3); + // A and B might have different dimensions. If so, broadcast the smaller one + // up to the size of the larger. 
+ int64_t out_dims = std::max(A->nDims(), B->nDims()); + // Add new outer broadcast dimensions if necessary + A = ops::maybe_broadcast_inner_to_rank(A, out_dims); + B = ops::maybe_broadcast_inner_to_rank(B, out_dims); + fms = fusedMultiplySum(A, B, {-1}); + mma_op = fms->definition()->as(); + } else { + NVF_ERROR( + false, + "Could not translate matmul pattern with output ", + output->toString(), + " to MmaOp"); + } + NVF_ERROR(fms != nullptr); + NVF_ERROR(mma_op != nullptr); + + // The following is common to both MatmulOp and LinearOp translation + + // TODO: skip downcasting if the only uses of `output` are casts back to + // higher precision in order to avoid the round trip cast in defining an + // epilogue that starts with MatmulOp. + if (output->dtype() != fms->dtype()) { + // Redefine output as cast of MmaOp->out() + IrBuilder::create(UnaryOpType::Cast, output, fms); + // Update output so that cast is part of the epilogue + output = fms; + } else { + // No cast needed, for example the inputs might be Float + ir_utils::transferDefinitionToNewOutputs(fms->definition(), {output}); + } + return mma_op; +} + +namespace { +// Determine dim roles for either a MatmulOp or a LinearOp, given IterDomain +// mappings +std::unordered_map matmulOrLinearOpDimRoles( + const ValGraph& exact_graph, + const std::vector& out_logical, + const std::vector& mapping_a, + const std::vector& mapping_b) { + std::unordered_map dim_roles; + NVF_ERROR(mapping_a.size() == out_logical.size()); + NVF_ERROR(mapping_a.size() == mapping_b.size()); + for (size_t i : c10::irange(out_logical.size())) { + IterDomain* id_out = out_logical[i]; + const ValGroup& g = exact_graph.toGroup(id_out); + + if (id_out->isReduction()) { + dim_roles[g] = MatmulDomain::K; + continue; + } + + bool has_a = mapping_a[i] != nullptr && mapping_a[i]->isIteration(); + bool has_b = mapping_b[i] != nullptr && mapping_b[i]->isIteration(); + + NVF_ERROR(has_a || has_b); + // If both operand IterDomains are Iteration, treat as 
Batch dimension + // If they mismatch, then one must be broadcast which determines M or N + if (has_a == has_b) { + dim_roles[g] = MatmulDomain::Batch; + } else if (has_a) { + dim_roles[g] = MatmulDomain::M; + } else if (has_b) { + dim_roles[g] = MatmulDomain::N; + } + } + return dim_roles; } +} // namespace std::unordered_map MatmulPattern::getDimRoles( IdModel& id_model) const { @@ -1530,6 +1665,31 @@ std::unordered_map MatmulPattern::getDimRoles( // If there are other patterns, for example a ValGroup present in only A, then // we should raise an exception here. + if (output->definition()->isA()) { + const std::vector& out_logical = output->getRFactorDomain(); + return matmulOrLinearOpDimRoles( + exact_graph, + out_logical, + ops::mapMatmulOpIterDomains( + A->getRFactorDomain(), MatmulRole::INPUT_A, out_logical.size()), + ops::mapMatmulOpIterDomains( + B->getRFactorDomain(), MatmulRole::INPUT_B, out_logical.size())); + + } else if (output->definition()->isA()) { + const std::vector& out_logical = output->getRFactorDomain(); + return matmulOrLinearOpDimRoles( + exact_graph, + out_logical, + ops::mapLinearOpIterDomains( + A->getRFactorDomain(), MatmulRole::INPUT_A, out_logical.size()), + ops::mapLinearOpIterDomains( + B->getRFactorDomain(), MatmulRole::INPUT_B, out_logical.size())); + } + + // The code below handles MmaOp or mul-sum patterns + + std::unordered_map dim_roles; + // Indicates whether a ValGroup is present in A (bit 0), B (bit 1), or output // (bit 2) using DimPresence = std::bitset<3>; @@ -1551,15 +1711,14 @@ std::unordered_map MatmulPattern::getDimRoles( recordPresence(B, 1); recordPresence(output, 2); - std::unordered_map dim_roles; for (const auto& [g, flags] : present_flags) { if (flags.all()) { dim_roles[g] = MatmulDomain::Batch; } else if (flags.test(0) && flags.test(1)) { dim_roles[g] = MatmulDomain::K; - } else if (flags.test(0) && flags.test(2)) { + } else if (flags.test(0) && !flags.test(1) && flags.test(2)) { dim_roles[g] = MatmulDomain::M; 
- } else if (flags.test(1) && flags.test(2)) { + } else if (!flags.test(0) && flags.test(1) && flags.test(2)) { dim_roles[g] = MatmulDomain::N; } else { NVF_ERROR( diff --git a/csrc/scheduler/no_op.cpp b/csrc/scheduler/no_op.cpp index 01eda65e043..de48d1adb01 100644 --- a/csrc/scheduler/no_op.cpp +++ b/csrc/scheduler/no_op.cpp @@ -50,6 +50,12 @@ bool NoOpScheduler::canScheduleCompileTime(Fusion* fusion) { return true; } + if (ir_utils::hasAnyMatmulOps(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), "matmul ops are not supported"); + return false; + } + // Check there're no non-trivial reduction ops. for (auto reduction : ir_utils::getAllTypesOfReductionOps(fusion)) { for (auto output : diff --git a/tests/cpp/test_combine_mul_sum.cpp b/tests/cpp/test_combine_mul_sum.cpp index 53a46764ad3..91d956f952b 100644 --- a/tests/cpp/test_combine_mul_sum.cpp +++ b/tests/cpp/test_combine_mul_sum.cpp @@ -27,6 +27,8 @@ #include #include #include +#include +#include #include #include #include @@ -269,6 +271,351 @@ TEST_F(CombineMulSumAsMmaTest, UseMatmulScheduler) { } } +using MatmulNodeTranslationTestParams = + std::tuple; +using MatmulNodeTranslationTest = + NVFuserFixtureParamTest; + +// Test that a simple matmul op fusion is picked up by the appropriate scheduler +// and the translation to MmaOp is performed properly. 
+TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + const int64_t A_dim = std::get<0>(GetParam()); + const int64_t B_dim = std::get<1>(GetParam()); + const bool enable_fusion = std::get<2>(GetParam()); + const bool transpose_a_alloc = std::get<3>(GetParam()); + const bool expect_segmented = std::get<4>(GetParam()); + const ScheduleHeuristic expected_heuristic = std::get<5>(GetParam()); + + // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); + + EnableOptionsGuard eog; + if (enable_fusion) { + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); + } + + // The allocation domain propagation pass sets the output allocation domain, + // which sometimes causes the matmul scheduler to decline the whole fusion + // when it could compile it otherwise. 
+ preseg_passes::OptimizationPassGuard + alloc_pass_guard(false); + + int batch_size = 3, M = 504, N = 136, K = 248; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(A_dim, DataType::Half); + auto tv1 = makeContigTensor(B_dim, DataType::Half); + + if (transpose_a_alloc && A_dim > 1) { + std::vector alloc = tv0->getMaybeAllocationDomain(); + alloc[alloc.size() - 1] = tv0->axis(-2); + alloc[alloc.size() - 2] = tv0->axis(-1); + tv0->setAllocationDomain(alloc, true); + } + + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = matmul(tv0, tv1); + + // add an epilogue + auto tv3 = sin(tv2); + auto tv4 = castOp(DataType::Half, tv3); + + fusion->addOutput(tv4); + + // Verify that we no longer set up MmaOp in matmul() + ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + std::vector A_shape(A_dim, batch_size); + A_shape[A_dim - 1] = K; + if (A_dim > 1) { + A_shape[A_dim - 2] = M; + } + at::Tensor t0 = at::randn(A_shape, options); + std::vector B_shape(B_dim, batch_size); + if (B_dim > 1) { + B_shape[B_dim - 2] = K; + B_shape[B_dim - 1] = N; + } else { + B_shape[B_dim - 1] = K; + } + auto t1 = at::randn(B_shape, options); + if (transpose_a_alloc) { + t0 = t0.as_strided({M, K}, {1, M}); + } + auto tref = at::matmul(t0, t1).sin().to(at::kHalf); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + ASSERT_NE(runtime, nullptr); + + if (expect_segmented) { + EXPECT_TRUE(runtime->isSegmented()); + } else { + EXPECT_FALSE(runtime->isSegmented()); + } + + ScheduleHeuristic heuristic = + runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); + EXPECT_EQ(heuristic, expected_heuristic); + + if (heuristic == ScheduleHeuristic::Matmul) { + // Ensure there's an MmaOp. 
+ EXPECT_FALSE( + ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) + .empty()); + } + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + , + MatmulNodeTranslationTest, + ::testing::Values( + // Tests without fusion enabled + std:: + make_tuple(2l, 2l, false, false, true, ScheduleHeuristic::ExprEval), + std::make_tuple(2l, 2l, false, true, true, ScheduleHeuristic::ExprEval), + // Tests with fusion enabled + std::make_tuple(2l, 2l, true, false, false, ScheduleHeuristic::Matmul), + // We cannot yet handle allocation domain in matmul scheduler + std::make_tuple(2l, 2l, true, true, true, ScheduleHeuristic::ExprEval), + // Size-1 input combinations + std::make_tuple(1l, 2l, true, false, true, ScheduleHeuristic::ExprEval), + std::make_tuple(2l, 1l, true, false, true, ScheduleHeuristic::ExprEval), + // We fuse this case using the Reduction scheduler + std::make_tuple( + 1l, + 1l, + true, + false, + false, + ScheduleHeuristic::Reduction), + // Batch dims + std::make_tuple(3l, 1l, true, false, true, ScheduleHeuristic::ExprEval), + std::make_tuple(3l, 3l, true, false, false, ScheduleHeuristic::Matmul), + // TODO: mixed length inputs via broadcasted batch dims + // We currently reject differently-sized inputs since these translate to + // multiple M or N dims + std::make_tuple(3l, 2l, true, false, true, ScheduleHeuristic::ExprEval), + std::make_tuple(2l, 3l, true, false, true, ScheduleHeuristic::ExprEval), + // TODO: More than one batch dimension is not yet supported in Matmul + // scheduler + std:: + make_tuple(4l, 4l, true, false, true, ScheduleHeuristic::ExprEval)), + [](const testing::TestParamInfo& info) { + std::ostringstream os; + os << std::get<0>(info.param) << "dA"; + os << "_" << std::get<1>(info.param) << "dB"; + if (!std::get<2>(info.param)) { + os << "_nofuse"; + } + if (std::get<3>(info.param)) { + os << "_transposeA"; + } + return os.str(); + }); + +using 
LinearNodeTranslationTestParams = + std::tuple; +using LinearNodeTranslationTest = + NVFuserFixtureParamTest; + +// Test that a simple linear op fusion is picked up by the appropriate scheduler +// and the translation to MmaOp is performed properly. +TEST_P(LinearNodeTranslationTest, AutomaticSchedulerLinearNode) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + // The allocation domain propagation pass sets the output allocation domain, + // which sometimes causes the matmul scheduler to decline the whole fusion + // when it could compile it otherwise. + preseg_passes::OptimizationPassGuard + alloc_pass_guard(false); + const int64_t A_dim = std::get<0>(GetParam()); + const int64_t B_dim = std::get<1>(GetParam()); + const int64_t bias_dim = std::get<2>(GetParam()); + const bool enable_fusion = std::get<3>(GetParam()); + const bool transpose_a_alloc = std::get<4>(GetParam()); + const bool expect_aten_eval = std::get<5>(GetParam()); + + // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it + // enabled + DisableOptionsGuard dog; + DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval); + + EnableOptionsGuard eog; + if (enable_fusion) { + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul); + } + + int batch_size = 3, M = 504, N = 136, K = 248; + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(A_dim, DataType::Half); + auto tv1 = makeContigTensor(B_dim, DataType::Half); + + if (transpose_a_alloc && A_dim > 1) { + std::vector alloc = tv0->getMaybeAllocationDomain(); + alloc[alloc.size() - 1] = tv0->axis(-2); + alloc[alloc.size() - 2] = tv0->axis(-1); + tv0->setAllocationDomain(alloc, true); + } + + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv2 = nullptr; + if (bias_dim >= 0) { + // bias_dim = -1 indicates we should not use any bias argument + auto bias = 
makeContigTensor(bias_dim, DataType::Half); + fusion->addInput(bias); + tv2 = linear(tv0, tv1, bias); + } else { + tv2 = linear(tv0, tv1); + } + + // add an epilogue + auto tv3 = sin(tv2); + auto tv4 = castOp(DataType::Half, tv3); + + fusion->addOutput(tv4); + + // Verify that we no longer set up MmaOp in linear() + ASSERT_TRUE(ir_utils::getOpsOfType(fusion.get()).empty()); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + std::vector A_shape(A_dim, batch_size); + A_shape[A_dim - 1] = K; + if (A_dim > 1) { + A_shape[A_dim - 2] = M; + } + at::Tensor t0 = at::randn(A_shape, options); + std::vector B_shape(B_dim, batch_size); + B_shape[B_dim - 1] = K; + if (B_dim > 1) { + B_shape[B_dim - 2] = N; + } + auto t1 = at::randn(B_shape, options); + if (transpose_a_alloc) { + t0 = t0.as_strided({M, K}, {1, M}); + } + std::vector inputs{t0, t1}; + at::Tensor tref; + if (bias_dim >= 0) { + at::Tensor bias; + if (bias_dim == 0) { + bias = at::randn({}, options); + } else if (bias_dim == 1) { + bias = at::randn({N}, options); + } else { + NVF_ERROR(false, "Invalid bias dimension given:", bias_dim); + } + inputs.emplace_back(bias); + tref = at::linear(t0, t1, bias); + } else { + tref = at::linear(t0, t1); + } + tref = tref.sin().to(at::kHalf); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + ASSERT_NE(runtime, nullptr); + + if (expect_aten_eval) { + EXPECT_TRUE(runtime->isSegmented()); + } else { + EXPECT_FALSE(runtime->isSegmented()); + } + + ScheduleHeuristic heuristic = + runtime->schedulerHeuristics()->heuristicsList().front()->heuristic(); + if (expect_aten_eval) { + EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval); + } else { + // Ensure that the Matmul scheduler ran. + // Assert here since we will inspect the kernel next, which we can't + // do if ExprEval accepts the segment. 
+ ASSERT_EQ(heuristic, ScheduleHeuristic::Matmul); + // Ensure there's an MmaOp. + EXPECT_FALSE( + ir_utils::getOpsOfType(runtime->executors().at(0).kernel()) + .empty()); + } + + testValidate( + executor_cache.fusion(), outputs, inputs, {tref}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + , + LinearNodeTranslationTest, + ::testing::Values( + // Tests without fusion enabled + std::make_tuple(2l, 2l, -1l, false, false, true), + std::make_tuple(2l, 2l, -1l, false, true, true), + std::make_tuple(1l, 2l, -1l, false, false, true), + std::make_tuple(2l, 1l, -1l, false, false, true), + std::make_tuple(2l, 2l, 1l, false, false, true), + std::make_tuple(1l, 1l, -1l, false, false, true), + std::make_tuple(3l, 2l, 1l, false, false, true), + std::make_tuple(4l, 2l, 1l, false, false, true), + // Enable fusion + std::make_tuple(2l, 2l, -1l, true, false, false), + // We cannot yet handle allocation domain in matmul scheduler + std::make_tuple(2l, 2l, -1l, true, true, true), + // We don't fuse 1D inputs + std::make_tuple(1l, 2l, -1l, true, false, true), + std::make_tuple(2l, 1l, -1l, true, false, true), + // Check that zero-dim output fusion is not claimed by NoOp scheduler + std::make_tuple(1l, 1l, -1l, true, false, true), + // Batch dims in input + // TODO: mixed length inputs via broadcasted batch dims + // We currently reject differently-sized inputs since these translate to + // multiple M or N dims + std::make_tuple(3l, 2l, -1l, true, false, true), + // TODO: We don't yet support multiple batch dims in matmul scheduler + std::make_tuple(4l, 2l, -1l, true, false, true), + // Bias cases + std::make_tuple(2l, 2l, 0l, true, false, false), + std::make_tuple(2l, 2l, 1l, true, false, false), + // TODO: Mixed-length inputs are rejected with bias also + std::make_tuple(3l, 2l, 1l, true, false, true), + // TODO: We don't yet support multiple batch dims in matmul scheduler + std::make_tuple(4l, 2l, 1l, true, false, true)), + [](const testing::TestParamInfo& info) { + 
std::ostringstream os; + os << std::get<0>(info.param) << "dA"; + os << "_" << std::get<1>(info.param) << "dB"; + int64_t bias_dim = std::get<2>(info.param); + if (bias_dim >= 0) { + os << "_" << bias_dim << "dBias"; + } + if (!std::get<3>(info.param)) { + os << "_nofuse"; + } + if (std::get<4>(info.param)) { + os << "_transposeA"; + } + return os.str(); + }); + // Check that we determine A and B properly when they are swapped as inputs to // mul TEST_F(CombineMulSumAsMmaTest, SwapAandB) { diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 5ffb26d40e9..2b2eeccfade 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -2604,12 +2604,11 @@ TEST_F(MatmulSchedulerPluginTest, BasicMatmul) { // TODO: Once we can control the ExprEval and Matmul schedulers via options, run // this test with all three combinations (with and without each scheduler, but // at least one enabled). -TEST_F(NVFuserTest, SegmentMatmulOpPrologue) { +TEST_F(MatmulSchedulerTest, SegmentMatmulOpPrologue) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // A - tv0, B - tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); fusion->addInput(tv0); @@ -2625,7 +2624,7 @@ TEST_F(NVFuserTest, SegmentMatmulOpPrologue) { NVF_CHECK( ir_utils::getOpsOfType(fusion.get()).size() == 1, - "matmul fusion must have at least one MmaOp"); + "matmul fusion must have at least one MatmulOp"); FusionExecutorCache executor_cache(std::move(fusion)); @@ -2645,12 +2644,11 @@ TEST_F(NVFuserTest, SegmentMatmulOpPrologue) { } // This is just like the above test but with LinearOp instead of MatmulOp -TEST_F(NVFuserTest, SegmentLinearOpPrologue) { +TEST_F(MatmulSchedulerTest, SegmentLinearOpPrologue) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - // A - tv0, B - 
tv1, C - tv2 auto tv0 = makeContigTensor(2, DataType::Half); auto tv1 = makeContigTensor(2, DataType::Half); fusion->addInput(tv0); @@ -2685,6 +2683,121 @@ TEST_F(NVFuserTest, SegmentLinearOpPrologue) { testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); } +// Test that the matmul scheduler refuses to translate a matmul that is not +// Half or BFloat16 +TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(2, DataType::Float); + auto tv1 = makeContigTensor(2, DataType::Float); + fusion->addInput(tv0); + fusion->addInput(tv1); + + // Prologue prevents ExprEval scheduler from accepting. If Matmul scheduler + // rejects, then Pointwise must not accept this unsegmented fusion. + tv1 = castOp(DataType::Float, sin(tv1)); + + auto tv2 = matmul(tv0, tv1); + + fusion->addOutput(tv2); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + + // Enable MatmulOp fusion, which should reject because float operands are not + // supported. 
+ EnableOptionsGuard eog; + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + + EXPECT_TRUE(runtime->isSegmented()); + + testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); +} + +class MatmulFusionTest : public MatmulSchedulerTest, + public ::testing::WithParamInterface { + protected: + void SetUp() override { + if (fusion_enabled) { + EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul); + } + } + + EnableOptionsGuard eog_; + bool fusion_enabled = GetParam(); +}; + +// Test that we can segment a Fusion containing two matmuls +TEST_P(MatmulFusionTest, Llama2FFN) { + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + auto tv2 = makeContigTensor(2, DataType::Half); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + auto tv3 = matmul(tv0, tv1); + auto tv4 = matmul(tv0, tv2); + + // silu + auto tv5 = mul(sigmoid(tv3), tv3); + + auto tv6 = mul(tv5, tv4); + + fusion->addOutput(tv6); + + FusionExecutorCache executor_cache(std::move(fusion)); + + const int M = 504, N = 136, K = 248; + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({K, N}, options); + auto t2 = at::randn({K, N}, options); + std::vector inputs{t0, t1, t2}; + + auto outputs = executor_cache.runFusionWithInputs(inputs); + + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + const FusionKernelRuntime* runtime = + executor_cache.getMostRecentKernelRuntime(); + + EXPECT_TRUE(runtime->isSegmented()); + + if (fusion_enabled) { + 
EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2); + } else { + EXPECT_EQ(runtime->fusionSegments()->groups().size(), 3); + } +} + +INSTANTIATE_TEST_SUITE_P( + , + MatmulFusionTest, + ::testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? "fuse" : "dontfuse"; + }); + // This test can be used to check that an external plugin has been loaded. It // is DISABLED_ so that the test suite will pass even if the user has not // provided a plugin via NVFUSER_MATMUL_HEURISTIC_PLUGIN. To check that a