From a695299ca0ef06473dfc229ffe738cea3a85b1c1 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 25 Apr 2024 22:45:28 +0000 Subject: [PATCH 01/35] add matmul node, create output tensorview --- csrc/ops/arith.cpp | 69 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index b7f0f7235a3..f353f48624f 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2684,4 +2684,73 @@ TensorView* tensor(Val* val) { return out; } +namespace { + +//! Create new output for matmul +static TensorView* newForMatmul( + TensorView* tv_a, + TensorView* tv_b +) { + auto orig_domain_a = tv_a->getMaybeRFactorDomain(); + auto orig_domain_b = tv_b->getMaybeRFactorDomain(); + auto ndims_a = orig_domain_a.size(); + auto ndims_b = orig_domain_b.size(); + + NVF_ERROR(ndims_a>=1 && ndims_b>= 1); + + std::vector new_domain; + + if (ndims_a > 2 || ndims_b > 2) { + auto higher_dim_domain = ndims_a >= ndims_b ? orig_domain_a : orig_domain_b; + auto lower_dim_domain = ndims_a >= ndims_b ? orig_domain_b : orig_domain_a; + auto higher_batch_ndims = higher_dim_domain.size() - 2; + auto lower_batch_ndims = lower_dim_domain.size() - 2; + + + auto batch_ndims = std::abs(higher_batch_ndims - 2); + auto non_common_batch_ndims = lower_batch_ndims > 0 ? higher_batch_ndims - lower_batch_ndims : higher_batch_ndims; + + // Add the first abs(ndims_a - ndims_b - 2) to the new domain + for (auto inx: c10::irange(batch_ndims)) { + if (inx < non_common_batch_ndims) { + new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx])); + } + // Check for common extents here + if (higher_dim_domain[inx]->extent() != 1){ + NVF_ERROR(higher_dim_domain[inx]->extent() == lower_dim_domain[inx - non_common_batch_ndims]->extent()); + new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx])); + } else { + new_domain.push_back(IterDomainBuilder(lower_dim_domain[inx - non_common_batch_ndims])); + } + } + } + + // Add M domain to output if present + if (orig_domain_a.size() > 1) { + new_domain.push_back(IterDomainBuilder(orig_domain_a[-2])); + } + + // Add N domain to output if present + if (orig_domain_b.size() > 1) { + new_domain.push_back(IterDomainBuilder(orig_domain_b[-1])); + } + + TensorDomain* td = IrBuilder::create( + new_domain, TensorDomain::getContiguityFilledWith(new_domain, true)); + + return IrBuilder::create(td, *tv_a->getDataType()); +} + +} // namespace + +TensorView* eagerMatmul( + TensorView* tv_a, + TensorView* tv_b) { + + NVF_CHECK(tv_a->getDataType().value() == tv_b->getDataType().value()); + TensorView* out = newForMatmul(tv_a, tv_b); + IrBuilder::create(out, tv_a, tv_b); + return out; +} + } // namespace nvfuser From 961ea76eb757757967e4d2275aab0318df7b72a8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 25 Apr 2024 23:20:09 +0000 Subject: [PATCH 02/35] wip scheduler --- csrc/ops/arith.cpp | 18 ++++---- csrc/scheduler/expr_eval_sched.cpp | 58 ++++++++++++++++++++++++++ csrc/scheduler/expr_eval_sched.h | 66 ++++++++++++++++++++++++++++++ 3 files changed, 132 insertions(+), 10 deletions(-) create mode 100644 csrc/scheduler/expr_eval_sched.cpp create mode 100644 csrc/scheduler/expr_eval_sched.h diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index f353f48624f..3708a770da0 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2706,33 +2706,31 @@ static TensorView* newForMatmul( auto higher_batch_ndims = higher_dim_domain.size() - 2; auto lower_batch_ndims = lower_dim_domain.size() - 2; - - auto batch_ndims = std::abs(higher_batch_ndims - 2); + auto batch_ndims = higher_batch_ndims - 2; auto non_common_batch_ndims = lower_batch_ndims > 0 ? higher_batch_ndims - lower_batch_ndims : higher_batch_ndims; - // Add the first abs(ndims_a - ndims_b - 2) to the new domain for (auto inx: c10::irange(batch_ndims)) { if (inx < non_common_batch_ndims) { - new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx])); + new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx]).resetSchedulingParams().build()); } // Check for common extents here - if (higher_dim_domain[inx]->extent() != 1){ - NVF_ERROR(higher_dim_domain[inx]->extent() == lower_dim_domain[inx - non_common_batch_ndims]->extent()); - new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx])); + if (higher_dim_domain[inx]->extent()->isOneInt()){ + new_domain.push_back(IterDomainBuilder(lower_dim_domain[inx - non_common_batch_ndims]).resetSchedulingParams().build()); } else { - new_domain.push_back(IterDomainBuilder(lower_dim_domain[inx - non_common_batch_ndims])); + NVF_ERROR(higher_dim_domain[inx]->extent() == lower_dim_domain[inx - non_common_batch_ndims]->extent()); + new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx]).resetSchedulingParams().build()); } } } // Add M domain to output if present if (orig_domain_a.size() > 1) { - new_domain.push_back(IterDomainBuilder(orig_domain_a[-2])); + new_domain.push_back(IterDomainBuilder(orig_domain_a[-2]).resetSchedulingParams().build()); } // Add N domain to output if present if (orig_domain_b.size() > 1) { - new_domain.push_back(IterDomainBuilder(orig_domain_b[-1])); + new_domain.push_back(IterDomainBuilder(orig_domain_b[-1]).resetSchedulingParams().build()); } TensorDomain* td = IrBuilder::create( diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp new file mode 100644 index 00000000000..88c4510d397 --- /dev/null +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -0,0 +1,58 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include + +namespace nvfuser { + +template +void vlog(const Args&... args) { + scheduler_debug_utils::log("[Expression Evaluator Scheduler] ", args...); +} + +ExprEvalScheduler::ExprEvalScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) + : SchedulerEntry(heuristicType()) { + params_ = std::make_shared("", runtime_info.getIndexType()); +} + +//! Check if the no-op heuristics apply in given fusion +bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { + // Check if the fusion has matmul node and accept + if (fusion->outputs().size() == 1 && fusion->outputs().front()->isA()) { + return true; + } + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), "Only accepts MatmulOp"); + return false; +} + +bool ExprEvalScheduler::canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + return true; +} + +void ExprEvalScheduler::schedule(Fusion* fusion) { + fusion->aliasOutputToInput( + fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); +} + +void ExprEvalScheduler::computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + return; +} +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h new file mode 100644 index 00000000000..b56610cc686 --- /dev/null +++ b/csrc/scheduler/expr_eval_sched.h @@ -0,0 +1,66 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { + +class Fusion; +class SchedulerRuntimeInfo; +class HeuristicSummary; + +//! ExprEval scheduler represents the case where we allocate outputs directly using EE. No code is generated. +class ExprEvalScheduler : public SchedulerEntry { + public: + explicit ExprEvalScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); + + //! This scheduler only accepts matmul and linear nodes + static bool canScheduleCompileTime(Fusion* fusion); + + static bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); + + constexpr static ScheduleHeuristic heuristicType() { + return ScheduleHeuristic::ExprEval; + } + + void schedule(Fusion* fusion) override; + + private: + void computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); +}; + +//! Provides a dummy heuristic type to ensure +//! unified interface on ExprEval scheduler. +class ExprEvalHeuristic : public HeuristicParams { + public: + using HeuristicParams::HeuristicParams; + + size_t hash() const override { + return 0; + } + std::shared_ptr clone() const override { + return std::make_shared(); + } + bool sameAs(const std::shared_ptr& other) const override { + auto other_casted = std::dynamic_pointer_cast(other); + return other_casted != nullptr && other_casted->cparams == cparams; + }; +}; + +} // namespace nvfuser \ No newline at end of file From fe23404715d5a8ec82c5fd2e79c4c56aaecabcbc Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 26 Apr 2024 18:25:18 +0000 Subject: [PATCH 03/35] add to dispatch, scheduler heuristic, registry --- CMakeLists.txt | 1 + csrc/dispatch.cpp | 7 +++++++ csrc/ops/arith.cpp | 7 ++++--- csrc/ops/arith.h | 4 ++++ csrc/scheduler/all_schedulers.h | 1 + csrc/scheduler/expr_eval_sched.cpp | 4 ++-- csrc/scheduler/heuristic_types.cpp | 2 ++ csrc/scheduler/heuristic_types.h | 9 ++++++--- csrc/scheduler/registry.cpp | 12 ++++++++++++ 9 files changed, 39 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8795f167edc..b573b0498af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,6 +205,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp + ${NVFUSER_SRCS_DIR}/scheduler/expr_eval_sched.cpp ${NVFUSER_SRCS_DIR}/serde/polymorphic_value.cpp ${NVFUSER_SRCS_DIR}/serde/utils.cpp ${NVFUSER_SRCS_DIR}/swizzle.cpp diff --git a/csrc/dispatch.cpp b/csrc/dispatch.cpp index e1353ff772b..b7b26a5d539 100644 --- a/csrc/dispatch.cpp +++ b/csrc/dispatch.cpp @@ -315,6 +315,10 @@ DISPATCH_FOR_ALL_KIR_EXPRS(M) DISPATCH_FOR_ALL_KIR_VALS(M) #undef M +void OptOutConstDispatch::handle(const MatmulOp* stmt) { + unhandled(stmt); +} + void OptOutDispatch::unhandled(Statement*) {} // Vals @@ -335,4 +339,7 @@ DISPATCH_FOR_ALL_KIR_VALS(M) DISPATCH_FOR_ALL_KIR_EXPRS(M) #undef M +void OptOutDispatch::handle(MatmulOp* stmt) { + unhandled(stmt); +} } // namespace nvfuser diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 3708a770da0..b7911ec7ada 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2725,18 +2725,19 @@ static TensorView* newForMatmul( // Add M domain to output if present if (orig_domain_a.size() > 1) { - new_domain.push_back(IterDomainBuilder(orig_domain_a[-2]).resetSchedulingParams().build()); + const IterDomain* m_id = orig_domain_a[ndims_a-2]; + new_domain.push_back(IterDomainBuilder(m_id).resetSchedulingParams().build()); } // Add N domain to output if present if (orig_domain_b.size() > 1) { - new_domain.push_back(IterDomainBuilder(orig_domain_b[-1]).resetSchedulingParams().build()); + new_domain.push_back(IterDomainBuilder(orig_domain_b[ndims_b-1]).resetSchedulingParams().build()); } TensorDomain* td = IrBuilder::create( new_domain, TensorDomain::getContiguityFilledWith(new_domain, true)); - return IrBuilder::create(td, *tv_a->getDataType()); + return IrBuilder::create(td, tv_a->dtype()); } } // namespace diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 17cf78074c3..904444fa530 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -806,4 +806,8 @@ NVF_API TensorView* tensor(const std::vector& vals) { return tensor(IrBuilder::arrayExpr(vals)); } +NVF_API TensorView* eagerMatmul( + TensorView* tv_a, + TensorView* tv_b); + } // namespace nvfuser diff --git a/csrc/scheduler/all_schedulers.h b/csrc/scheduler/all_schedulers.h index 5bdf013c2de..08a33343d6a 100644 --- a/csrc/scheduler/all_schedulers.h +++ b/csrc/scheduler/all_schedulers.h @@ -6,6 +6,7 @@ */ // clang-format on #pragma once +#include #include #include #include diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 88c4510d397..11a7989b6b0 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include namespace nvfuser { @@ -29,7 +29,7 @@ ExprEvalScheduler::ExprEvalScheduler( //! Check if the no-op heuristics apply in given fusion bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { // Check if the fusion has matmul node and accept - if (fusion->outputs().size() == 1 && fusion->outputs().front()->isA()) { + if (fusion->outputs().size() == 1 && fusion->outputs().front()->definition()->isA()) { return true; } scheduler_debug_utils::canScheduleRejectReason( diff --git a/csrc/scheduler/heuristic_types.cpp b/csrc/scheduler/heuristic_types.cpp index c85266685c7..3b1c5a7e97e 100644 --- a/csrc/scheduler/heuristic_types.cpp +++ b/csrc/scheduler/heuristic_types.cpp @@ -29,6 +29,8 @@ std::string toString(ScheduleHeuristic sh) { return "transpose"; case ScheduleHeuristic::Matmul: return "matmul"; + case ScheduleHeuristic::ExprEval: + return "expr_eval"; case ScheduleHeuristic::None: return "none"; default: diff --git a/csrc/scheduler/heuristic_types.h b/csrc/scheduler/heuristic_types.h index e30169141b4..16dbe273e75 100644 --- a/csrc/scheduler/heuristic_types.h +++ b/csrc/scheduler/heuristic_types.h @@ -55,19 +55,22 @@ enum class ScheduleHeuristic { InnerPersistent, InnerOuterPersistent, OuterPersistent, - Transpose + Transpose, + ExprEval }; //! Define a schedule table to loop over all the heuristics in priority order. -constexpr std::array all_heuristics_in_priority_order = { +constexpr std::array all_heuristics_in_priority_order = { ScheduleHeuristic::NoOp, + ScheduleHeuristic::ExprEval, ScheduleHeuristic::Matmul, ScheduleHeuristic::Reduction, ScheduleHeuristic::Transpose, ScheduleHeuristic::PointWise, ScheduleHeuristic::InnerPersistent, ScheduleHeuristic::OuterPersistent, - ScheduleHeuristic::InnerOuterPersistent}; + ScheduleHeuristic::InnerOuterPersistent + }; std::string toString(ScheduleHeuristic sh); diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index a9680ecb2f5..462c8943e0c 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -206,6 +206,8 @@ bool checkCanSchedule( case ScheduleHeuristic::Matmul: return checkCanSchedule( fusion, runtime_info, data_cache); + case ScheduleHeuristic::ExprEval: + return checkCanSchedule(fusion, runtime_info, data_cache); default: NVF_ERROR(false, "unreachable"); return false; @@ -252,6 +254,10 @@ bool checkCanSchedule( scheduler_entry = std::make_unique(fusion, runtime_info, data_cache); break; + case ScheduleHeuristic::ExprEval: + scheduler_entry = + std::make_unique(fusion, runtime_info, data_cache); + break; default: NVF_ERROR(false, "unreachable"); } @@ -342,6 +348,9 @@ HeuristicSummary::HeuristicSummary( NVF_ERROR(canSchedule, "Could not schedule matmul (run time)"); break; } + case ScheduleHeuristic::ExprEval: + ExprEvalScheduler::canScheduleRunTime(fusion, runtime_info, this); + break; default: NVF_ERROR(false, "unknown heuristic"); } @@ -419,6 +428,9 @@ void HeuristicSummary::validate() const { // TODO: add a proper set of checks break; } + case ScheduleHeuristic::ExprEval: { + break; + } default: NVF_ERROR(false, "unknown heuristic"); } From 075a83f9289a2324e7898a76114166426dd6a102 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 29 Apr 2024 19:46:27 +0000 Subject: [PATCH 04/35] root map override --- csrc/ops/arith.cpp | 22 ++++++++++++++-------- csrc/root_domain_map.h | 5 +++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index b7911ec7ada..cae5b083d15 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2691,21 +2691,22 @@ static TensorView* newForMatmul( TensorView* tv_a, TensorView* tv_b ) { - auto orig_domain_a = tv_a->getMaybeRFactorDomain(); - auto orig_domain_b = tv_b->getMaybeRFactorDomain(); + auto orig_domain_a = TensorDomain::noReductions(tv_a->getMaybeRFactorDomain()); + auto orig_domain_b = TensorDomain::noReductions(tv_b->getMaybeRFactorDomain()); auto ndims_a = orig_domain_a.size(); auto ndims_b = orig_domain_b.size(); NVF_ERROR(ndims_a>=1 && ndims_b>= 1); std::vector new_domain; + auto higher_dim_domain = ndims_a >= ndims_b ? orig_domain_a : orig_domain_b; + auto lower_dim_domain = ndims_a >= ndims_b ? orig_domain_b : orig_domain_a; + + new_domain.reserve(higher_dim_domain.size()); if (ndims_a > 2 || ndims_b > 2) { - auto higher_dim_domain = ndims_a >= ndims_b ? orig_domain_a : orig_domain_b; - auto lower_dim_domain = ndims_a >= ndims_b ? orig_domain_b : orig_domain_a; auto higher_batch_ndims = higher_dim_domain.size() - 2; auto lower_batch_ndims = lower_dim_domain.size() - 2; - auto batch_ndims = higher_batch_ndims - 2; auto non_common_batch_ndims = lower_batch_ndims > 0 ? higher_batch_ndims - lower_batch_ndims : higher_batch_ndims; @@ -2723,15 +2724,19 @@ static TensorView* newForMatmul( } } - // Add M domain to output if present + // Add M domain to output domain if present if (orig_domain_a.size() > 1) { const IterDomain* m_id = orig_domain_a[ndims_a-2]; new_domain.push_back(IterDomainBuilder(m_id).resetSchedulingParams().build()); } - // Add N domain to output if present + // const IterDomain* k_id = orig_domain_a[ndims_a-1]; + // new_domain.push_back(IterDomainBuilder(k_id).resetSchedulingParams().iter_type(IterType::Reduction).build()); + + // Add N domain to output domain if present if (orig_domain_b.size() > 1) { - new_domain.push_back(IterDomainBuilder(orig_domain_b[ndims_b-1]).resetSchedulingParams().build()); + const IterDomain* n_id = orig_domain_b[ndims_b-1]; + new_domain.push_back(IterDomainBuilder(n_id).resetSchedulingParams().build()); } TensorDomain* td = IrBuilder::create( @@ -2748,6 +2753,7 @@ TensorView* eagerMatmul( NVF_CHECK(tv_a->getDataType().value() == tv_b->getDataType().value()); TensorView* out = newForMatmul(tv_a, tv_b); + // TensorView* out = newForMma(tv_a, tv_b, {1}); IrBuilder::create(out, tv_a, tv_b); return out; } diff --git a/csrc/root_domain_map.h b/csrc/root_domain_map.h index d39bdaa64d7..68a9cb781c9 100644 --- a/csrc/root_domain_map.h +++ b/csrc/root_domain_map.h @@ -245,6 +245,7 @@ class UnmappableReductionDomains : private IterVisitor { void handle(GroupedReductionOp* op) override; void handle(WelfordOp* op) override; void handle(MmaOp* op) override; + // void handle(MatmulOp* op) override; void handleReductionOutput(TensorView* out_tv); @@ -492,6 +493,10 @@ class ComputeAtRootDomainMapBuilder : private BackwardVisitor { mapPointwiseLikeOp(wop); } + void handle(MatmulOp* wop) override { + mapPointwiseLikeOp(wop); + } + void handle(ShiftOp* op) override { mapPointwiseLikeOp(op); } From a7a15a8e36d0e192db01bd74d36d8ceb450e611f Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 1 May 2024 03:27:25 +0000 Subject: [PATCH 05/35] rebase --- csrc/dispatch.cpp | 8 -------- csrc/root_domain_map.h | 5 ----- 2 files changed, 13 deletions(-) diff --git a/csrc/dispatch.cpp b/csrc/dispatch.cpp index b7b26a5d539..c0df251e0b7 100644 --- a/csrc/dispatch.cpp +++ b/csrc/dispatch.cpp @@ -315,10 +315,6 @@ DISPATCH_FOR_ALL_KIR_EXPRS(M) DISPATCH_FOR_ALL_KIR_VALS(M) #undef M -void OptOutConstDispatch::handle(const MatmulOp* stmt) { - unhandled(stmt); -} - void OptOutDispatch::unhandled(Statement*) {} // Vals @@ -338,8 +334,4 @@ M(assoc_comm::FlattenedAssocCommOp) DISPATCH_FOR_ALL_KIR_VALS(M) DISPATCH_FOR_ALL_KIR_EXPRS(M) #undef M - -void OptOutDispatch::handle(MatmulOp* stmt) { - unhandled(stmt); -} } // namespace nvfuser diff --git a/csrc/root_domain_map.h b/csrc/root_domain_map.h index 68a9cb781c9..d39bdaa64d7 100644 --- a/csrc/root_domain_map.h +++ b/csrc/root_domain_map.h @@ -245,7 +245,6 @@ class UnmappableReductionDomains : private IterVisitor { void handle(GroupedReductionOp* op) override; void handle(WelfordOp* op) override; void handle(MmaOp* op) override; - // void handle(MatmulOp* op) override; void handleReductionOutput(TensorView* out_tv); @@ -493,10 +492,6 @@ class ComputeAtRootDomainMapBuilder : private BackwardVisitor { mapPointwiseLikeOp(wop); } - void handle(MatmulOp* wop) override { - mapPointwiseLikeOp(wop); - } - void handle(ShiftOp* op) override { mapPointwiseLikeOp(op); } From 15b90625ff5d467f7fc2bca5de745d4ee00e8c05 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 1 May 2024 05:35:37 +0000 Subject: [PATCH 06/35] mapping for matmul ir node --- csrc/ops/arith.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index cae5b083d15..d5e4f8e62e4 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2705,17 +2705,15 @@ static TensorView* newForMatmul( new_domain.reserve(higher_dim_domain.size()); if (ndims_a > 2 || ndims_b > 2) { - auto higher_batch_ndims = higher_dim_domain.size() - 2; - auto lower_batch_ndims = lower_dim_domain.size() - 2; + auto higher_batch_ndims = higher_dim_domain.size(); + auto lower_batch_ndims = lower_dim_domain.size(); auto batch_ndims = higher_batch_ndims - 2; - auto non_common_batch_ndims = lower_batch_ndims > 0 ? higher_batch_ndims - lower_batch_ndims : higher_batch_ndims; + auto non_common_batch_ndims = lower_batch_ndims > 2 ? higher_batch_ndims - lower_batch_ndims : batch_ndims; for (auto inx: c10::irange(batch_ndims)) { if (inx < non_common_batch_ndims) { new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx]).resetSchedulingParams().build()); - } - // Check for common extents here - if (higher_dim_domain[inx]->extent()->isOneInt()){ + } else if (higher_dim_domain[inx]->extent()->isOneInt()){ new_domain.push_back(IterDomainBuilder(lower_dim_domain[inx - non_common_batch_ndims]).resetSchedulingParams().build()); } else { NVF_ERROR(higher_dim_domain[inx]->extent() == lower_dim_domain[inx - non_common_batch_ndims]->extent()); From 44d13c46762176665db7117d0c162e6dddc15cb2 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 1 May 2024 21:51:18 +0000 Subject: [PATCH 07/35] use higher dim inp to create output --- csrc/ops/arith.cpp | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index d5e4f8e62e4..5bdd09fac9e 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2699,26 +2699,12 @@ static TensorView* newForMatmul( NVF_ERROR(ndims_a>=1 && ndims_b>= 1); std::vector new_domain; - auto higher_dim_domain = ndims_a >= ndims_b ? orig_domain_a : orig_domain_b; - auto lower_dim_domain = ndims_a >= ndims_b ? orig_domain_b : orig_domain_a; - + auto higher_dim_domain = ndims_a >= ndims_b ? orig_domain_a : orig_domain_b; new_domain.reserve(higher_dim_domain.size()); if (ndims_a > 2 || ndims_b > 2) { - auto higher_batch_ndims = higher_dim_domain.size(); - auto lower_batch_ndims = lower_dim_domain.size(); - auto batch_ndims = higher_batch_ndims - 2; - auto non_common_batch_ndims = lower_batch_ndims > 2 ? higher_batch_ndims - lower_batch_ndims : batch_ndims; - - for (auto inx: c10::irange(batch_ndims)) { - if (inx < non_common_batch_ndims) { - new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx]).resetSchedulingParams().build()); - } else if (higher_dim_domain[inx]->extent()->isOneInt()){ - new_domain.push_back(IterDomainBuilder(lower_dim_domain[inx - non_common_batch_ndims]).resetSchedulingParams().build()); - } else { - NVF_ERROR(higher_dim_domain[inx]->extent() == lower_dim_domain[inx - non_common_batch_ndims]->extent()); - new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx]).resetSchedulingParams().build()); - } + for (auto inx: c10::irange(higher_dim_domain.size() - 2)) { + new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx]).resetSchedulingParams().build()); } } @@ -2728,9 +2714,6 @@ static TensorView* newForMatmul( new_domain.push_back(IterDomainBuilder(m_id).resetSchedulingParams().build()); } - // const IterDomain* k_id = orig_domain_a[ndims_a-1]; - // new_domain.push_back(IterDomainBuilder(k_id).resetSchedulingParams().iter_type(IterType::Reduction).build()); - // Add N domain to output domain if present if (orig_domain_b.size() > 1) { const IterDomain* n_id = orig_domain_b[ndims_b-1]; @@ -2748,10 +2731,8 @@ static TensorView* newForMatmul( TensorView* eagerMatmul( TensorView* tv_a, TensorView* tv_b) { - NVF_CHECK(tv_a->getDataType().value() == tv_b->getDataType().value()); TensorView* out = newForMatmul(tv_a, tv_b); - // TensorView* out = newForMma(tv_a, tv_b, {1}); IrBuilder::create(out, tv_a, tv_b); return out; } From 0cf2e6b77fb099cb0f1c07cf9cc4af2956e38374 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 2 May 2024 22:42:16 +0000 Subject: [PATCH 08/35] remove scheduler" --- CMakeLists.txt | 1 - csrc/scheduler/all_schedulers.h | 1 - csrc/scheduler/expr_eval_sched.cpp | 58 -------------------------- csrc/scheduler/expr_eval_sched.h | 66 ------------------------------ csrc/scheduler/heuristic_types.cpp | 2 - csrc/scheduler/heuristic_types.h | 4 +- csrc/scheduler/registry.cpp | 12 ------ 7 files changed, 1 insertion(+), 143 deletions(-) delete mode 100644 csrc/scheduler/expr_eval_sched.cpp delete mode 100644 csrc/scheduler/expr_eval_sched.h diff --git a/CMakeLists.txt b/CMakeLists.txt index b573b0498af..8795f167edc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,7 +205,6 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp - ${NVFUSER_SRCS_DIR}/scheduler/expr_eval_sched.cpp ${NVFUSER_SRCS_DIR}/serde/polymorphic_value.cpp ${NVFUSER_SRCS_DIR}/serde/utils.cpp ${NVFUSER_SRCS_DIR}/swizzle.cpp diff --git a/csrc/scheduler/all_schedulers.h b/csrc/scheduler/all_schedulers.h index 08a33343d6a..5bdf013c2de 100644 --- a/csrc/scheduler/all_schedulers.h +++ b/csrc/scheduler/all_schedulers.h @@ -6,7 +6,6 @@ */ // clang-format on #pragma once -#include #include #include #include diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp deleted file mode 100644 index 11a7989b6b0..00000000000 --- a/csrc/scheduler/expr_eval_sched.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on - -#include -#include -#include -#include - -namespace nvfuser { - -template -void vlog(const Args&... args) { - scheduler_debug_utils::log("[Expression Evaluator Scheduler] ", args...); -} - -ExprEvalScheduler::ExprEvalScheduler( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache) - : SchedulerEntry(heuristicType()) { - params_ = std::make_shared("", runtime_info.getIndexType()); -} - -//! Check if the no-op heuristics apply in given fusion -bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { - // Check if the fusion has matmul node and accept - if (fusion->outputs().size() == 1 && fusion->outputs().front()->definition()->isA()) { - return true; - } - scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), "Only accepts MatmulOp"); - return false; -} - -bool ExprEvalScheduler::canScheduleRunTime( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache) { - return true; -} - -void ExprEvalScheduler::schedule(Fusion* fusion) { - fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); -} - -void ExprEvalScheduler::computeHeuristics( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache) { - return; -} -} // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h deleted file mode 100644 index b56610cc686..00000000000 --- a/csrc/scheduler/expr_eval_sched.h +++ /dev/null @@ -1,66 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include - -namespace nvfuser { - -class Fusion; -class SchedulerRuntimeInfo; -class HeuristicSummary; - -//! ExprEval scheduler represents the case where we allocate outputs directly using EE. No code is generated. -class ExprEvalScheduler : public SchedulerEntry { - public: - explicit ExprEvalScheduler( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr); - - //! This scheduler only accepts matmul and linear nodes - static bool canScheduleCompileTime(Fusion* fusion); - - static bool canScheduleRunTime( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr); - - constexpr static ScheduleHeuristic heuristicType() { - return ScheduleHeuristic::ExprEval; - } - - void schedule(Fusion* fusion) override; - - private: - void computeHeuristics( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr); -}; - -//! Provides a dummy heuristic type to ensure -//! unified interface on ExprEval scheduler. -class ExprEvalHeuristic : public HeuristicParams { - public: - using HeuristicParams::HeuristicParams; - - size_t hash() const override { - return 0; - } - std::shared_ptr clone() const override { - return std::make_shared(); - } - bool sameAs(const std::shared_ptr& other) const override { - auto other_casted = std::dynamic_pointer_cast(other); - return other_casted != nullptr && other_casted->cparams == cparams; - }; -}; - -} // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/heuristic_types.cpp b/csrc/scheduler/heuristic_types.cpp index 3b1c5a7e97e..c85266685c7 100644 --- a/csrc/scheduler/heuristic_types.cpp +++ b/csrc/scheduler/heuristic_types.cpp @@ -29,8 +29,6 @@ std::string toString(ScheduleHeuristic sh) { return "transpose"; case ScheduleHeuristic::Matmul: return "matmul"; - case ScheduleHeuristic::ExprEval: - return "expr_eval"; case ScheduleHeuristic::None: return "none"; default: diff --git a/csrc/scheduler/heuristic_types.h b/csrc/scheduler/heuristic_types.h index 16dbe273e75..d2e402d9492 100644 --- a/csrc/scheduler/heuristic_types.h +++ b/csrc/scheduler/heuristic_types.h @@ -56,13 +56,11 @@ enum class ScheduleHeuristic { InnerOuterPersistent, OuterPersistent, Transpose, - ExprEval }; //! Define a schedule table to loop over all the heuristics in priority order. -constexpr std::array all_heuristics_in_priority_order = { +constexpr std::array all_heuristics_in_priority_order = { ScheduleHeuristic::NoOp, - ScheduleHeuristic::ExprEval, ScheduleHeuristic::Matmul, ScheduleHeuristic::Reduction, ScheduleHeuristic::Transpose, diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 462c8943e0c..a9680ecb2f5 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -206,8 +206,6 @@ bool checkCanSchedule( case ScheduleHeuristic::Matmul: return checkCanSchedule( fusion, runtime_info, data_cache); - case ScheduleHeuristic::ExprEval: - return checkCanSchedule(fusion, runtime_info, data_cache); default: NVF_ERROR(false, "unreachable"); return false; @@ -254,10 +252,6 @@ bool checkCanSchedule( scheduler_entry = std::make_unique(fusion, runtime_info, data_cache); break; - case ScheduleHeuristic::ExprEval: - scheduler_entry = - std::make_unique(fusion, runtime_info, data_cache); - break; default: NVF_ERROR(false, "unreachable"); } @@ -348,9 +342,6 @@ HeuristicSummary::HeuristicSummary( NVF_ERROR(canSchedule, "Could not schedule matmul (run time)"); break; } - case ScheduleHeuristic::ExprEval: - ExprEvalScheduler::canScheduleRunTime(fusion, runtime_info, this); - break; default: NVF_ERROR(false, "unknown heuristic"); } @@ -428,9 +419,6 @@ void HeuristicSummary::validate() const { // TODO: add a proper set of checks break; } - case ScheduleHeuristic::ExprEval: { - break; - } default: NVF_ERROR(false, "unknown heuristic"); } From bb5252612c02cca39d0b88e9162f547c317391be Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 2 May 2024 22:46:57 +0000 Subject: [PATCH 09/35] review comments --- csrc/dispatch.cpp | 1 + csrc/scheduler/heuristic_types.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/dispatch.cpp b/csrc/dispatch.cpp index c0df251e0b7..e1353ff772b 100644 --- a/csrc/dispatch.cpp +++ b/csrc/dispatch.cpp @@ -334,4 +334,5 @@ M(assoc_comm::FlattenedAssocCommOp) DISPATCH_FOR_ALL_KIR_VALS(M) DISPATCH_FOR_ALL_KIR_EXPRS(M) #undef M + } // namespace nvfuser diff --git a/csrc/scheduler/heuristic_types.h b/csrc/scheduler/heuristic_types.h index d2e402d9492..2ca5b8f779d 100644 --- a/csrc/scheduler/heuristic_types.h +++ b/csrc/scheduler/heuristic_types.h @@ -55,7 +55,7 @@ enum class ScheduleHeuristic { InnerPersistent, InnerOuterPersistent, OuterPersistent, - Transpose, + Transpose }; //! Define a schedule table to loop over all the heuristics in priority order. From 6c130d7f4697fd5f0e262c06eae6275a538c8b7e Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 2 May 2024 23:32:44 +0000 Subject: [PATCH 10/35] modify pairwise matching --- csrc/ops/arith.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 5bdd09fac9e..73b7fcc5565 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2728,6 +2728,9 @@ static TensorView* newForMatmul( } // namespace + +// TODO (Priya): This will be renamed to matmul once we are ready to modify the python API backend. +// Keeping separate for now, to avoid breaking tests in Thunder. TensorView* eagerMatmul( TensorView* tv_a, TensorView* tv_b) { From 6bb60316de3f2644d073a527bb56a60bd48c88dd Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Sat, 4 May 2024 07:37:34 +0000 Subject: [PATCH 11/35] modify matmul out allocation --- csrc/ops/arith.cpp | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 73b7fcc5565..7984566f331 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2695,33 +2695,27 @@ static TensorView* newForMatmul( auto orig_domain_b = TensorDomain::noReductions(tv_b->getMaybeRFactorDomain()); auto ndims_a = orig_domain_a.size(); auto ndims_b = orig_domain_b.size(); - + auto ndims_out = std::max(ndims_a, ndims_b); NVF_ERROR(ndims_a>=1 && ndims_b>= 1); - std::vector new_domain; - auto higher_dim_domain = ndims_a >= ndims_b ? orig_domain_a : orig_domain_b; - new_domain.reserve(higher_dim_domain.size()); - - if (ndims_a > 2 || ndims_b > 2) { - for (auto inx: c10::irange(higher_dim_domain.size() - 2)) { - new_domain.push_back(IterDomainBuilder(higher_dim_domain[inx]).resetSchedulingParams().build()); - } - } + std::vector out_domain(ndims_out, nullptr); - // Add M domain to output domain if present - if (orig_domain_a.size() > 1) { - const IterDomain* m_id = orig_domain_a[ndims_a-2]; - new_domain.push_back(IterDomainBuilder(m_id).resetSchedulingParams().build()); - } + int kpos_a = ndims_a - 1; + int kpos_b = ndims_b > 1 ? ndims_b - 2 : 0; - // Add N domain to output domain if present - if (orig_domain_b.size() > 1) { - const IterDomain* n_id = orig_domain_b[ndims_b-1]; - new_domain.push_back(IterDomainBuilder(n_id).resetSchedulingParams().build()); - } + for (int inx = ndims_out - 1, inx_a = ndims_a - 1, inx_b = ndims_b - 1; inx >= 0; inx--, inx_a--, inx_b--){ + std::vector input_ids; + if (inx_a >= 0 && inx_a != kpos_a){ + input_ids.emplace_back(orig_domain_a[inx_a]); + } + if (inx_b >= 0 && inx_b != kpos_b){ + input_ids.emplace_back(orig_domain_b[inx_b]); + } + out_domain[inx] = ops::outIterDomain(input_ids); + } TensorDomain* td = IrBuilder::create( - new_domain, TensorDomain::getContiguityFilledWith(new_domain, true)); + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); return IrBuilder::create(td, tv_a->dtype()); } From b75d72fc4db8b2a9dd34abff4c3bf14429c58f46 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 6 May 2024 22:14:05 +0000 Subject: [PATCH 12/35] move mapping logic to another function --- csrc/ops/arith.cpp | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 7984566f331..e6159d3108b 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2693,26 +2693,33 @@ static TensorView* newForMatmul( ) { auto orig_domain_a = TensorDomain::noReductions(tv_a->getMaybeRFactorDomain()); auto orig_domain_b = TensorDomain::noReductions(tv_b->getMaybeRFactorDomain()); + auto ndims_a = orig_domain_a.size(); auto ndims_b = orig_domain_b.size(); - auto ndims_out = std::max(ndims_a, ndims_b); - NVF_ERROR(ndims_a>=1 && ndims_b>= 1); - std::vector out_domain(ndims_out, nullptr); + NVF_ERROR(ndims_a >= 1 && ndims_b >= 1); + + auto ndims_out = std::max(ndims_a, ndims_b); + if (std::min(ndims_a, ndims_b) == 1) { + ndims_out = std::max(ndims_a, ndims_b) - 1; + } - int kpos_a = ndims_a - 1; - int kpos_b = ndims_b > 1 ? ndims_b - 2 : 0; + std::vector out_domain(ndims_out, nullptr); - for (int inx = ndims_out - 1, inx_a = ndims_a - 1, inx_b = ndims_b - 1; inx >= 0; inx--, inx_a--, inx_b--){ + const auto& mapping_a = ops::mapMatmulIterDomains(orig_domain_a, true, ndims_out); + const auto& mapping_b = ops::mapMatmulIterDomains(orig_domain_b, false, ndims_out); + + for (auto inx: c10::irange(ndims_out)){ std::vector input_ids; - if (inx_a >= 0 && inx_a != kpos_a){ - input_ids.emplace_back(orig_domain_a[inx_a]); + input_ids.reserve(2); + if (mapping_a[inx] != nullptr){ + input_ids.emplace_back(mapping_a[inx]); } - if (inx_b >= 0 && inx_b != kpos_b){ - input_ids.emplace_back(orig_domain_b[inx_b]); + if (mapping_b[inx] != nullptr){ + input_ids.emplace_back(mapping_b[inx]); } - out_domain[inx] = ops::outIterDomain(input_ids); - } + out_domain[inx] = ops::newOutputIterDomain(input_ids); + } TensorDomain* td = IrBuilder::create( out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); From 54314bf0f708a70ba7ad633ac9912d8184777157 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 6 May 2024 22:30:54 +0000 Subject: [PATCH 13/35] use mapping in root domain --- csrc/root_domain_map.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index b82213fea51..48d9e3816bf 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include From 22031f2fb6d48669f8e88a8d394e78678dc7c0b5 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 6 May 2024 22:48:24 +0000 Subject: [PATCH 14/35] comment --- csrc/ops/arith.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index e6159d3108b..4c2ceba4a4e 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2706,8 +2706,8 @@ static TensorView* newForMatmul( std::vector out_domain(ndims_out, nullptr); - const auto& mapping_a = ops::mapMatmulIterDomains(orig_domain_a, true, ndims_out); - const auto& mapping_b = ops::mapMatmulIterDomains(orig_domain_b, false, ndims_out); + const auto& mapping_a = ops::mapMatmulOpIterDomains(orig_domain_a, true, ndims_out); + const auto& mapping_b = ops::mapMatmulOpIterDomains(orig_domain_b, false, ndims_out); for (auto inx: c10::irange(ndims_out)){ std::vector input_ids; From 76fd527ad5e2923c0ad942c2f44ea186a6084507 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 6 May 2024 22:56:19 +0000 Subject: [PATCH 15/35] add dot product case --- csrc/ops/arith.cpp | 57 ---------------------------------------------- csrc/ops/arith.h | 4 ---- 2 files changed, 61 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 4c2ceba4a4e..b7f0f7235a3 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2684,61 +2684,4 @@ TensorView* tensor(Val* val) { return out; } -namespace { - -//! Create new output for matmul -static TensorView* newForMatmul( - TensorView* tv_a, - TensorView* tv_b -) { - auto orig_domain_a = TensorDomain::noReductions(tv_a->getMaybeRFactorDomain()); - auto orig_domain_b = TensorDomain::noReductions(tv_b->getMaybeRFactorDomain()); - - auto ndims_a = orig_domain_a.size(); - auto ndims_b = orig_domain_b.size(); - - NVF_ERROR(ndims_a >= 1 && ndims_b >= 1); - - auto ndims_out = std::max(ndims_a, ndims_b); - if (std::min(ndims_a, ndims_b) == 1) { - ndims_out = std::max(ndims_a, ndims_b) - 1; - } - - std::vector out_domain(ndims_out, nullptr); - - const auto& mapping_a = ops::mapMatmulOpIterDomains(orig_domain_a, true, ndims_out); - const auto& mapping_b = ops::mapMatmulOpIterDomains(orig_domain_b, false, ndims_out); - - for (auto inx: c10::irange(ndims_out)){ - std::vector input_ids; - input_ids.reserve(2); - if (mapping_a[inx] != nullptr){ - input_ids.emplace_back(mapping_a[inx]); - } - if (mapping_b[inx] != nullptr){ - input_ids.emplace_back(mapping_b[inx]); - } - out_domain[inx] = ops::newOutputIterDomain(input_ids); - } - - TensorDomain* td = IrBuilder::create( - out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); - - return IrBuilder::create(td, tv_a->dtype()); -} - -} // namespace - - -// TODO (Priya): This will be renamed to matmul once we are ready to modify the python API backend. -// Keeping separate for now, to avoid breaking tests in Thunder. -TensorView* eagerMatmul( - TensorView* tv_a, - TensorView* tv_b) { - NVF_CHECK(tv_a->getDataType().value() == tv_b->getDataType().value()); - TensorView* out = newForMatmul(tv_a, tv_b); - IrBuilder::create(out, tv_a, tv_b); - return out; -} - } // namespace nvfuser diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 904444fa530..17cf78074c3 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -806,8 +806,4 @@ NVF_API TensorView* tensor(const std::vector& vals) { return tensor(IrBuilder::arrayExpr(vals)); } -NVF_API TensorView* eagerMatmul( - TensorView* tv_a, - TensorView* tv_b); - } // namespace nvfuser From dc45b685401cc9dac90506f5b880cfa4e33955ea Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 6 May 2024 23:00:23 +0000 Subject: [PATCH 16/35] format --- csrc/root_domain_map.cpp | 1 + csrc/scheduler/heuristic_types.h | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 48d9e3816bf..69116f4c09a 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include diff --git a/csrc/scheduler/heuristic_types.h b/csrc/scheduler/heuristic_types.h index 2ca5b8f779d..e30169141b4 100644 --- a/csrc/scheduler/heuristic_types.h +++ b/csrc/scheduler/heuristic_types.h @@ -67,8 +67,7 @@ constexpr std::array all_heuristics_in_priority_order = { ScheduleHeuristic::PointWise, ScheduleHeuristic::InnerPersistent, ScheduleHeuristic::OuterPersistent, - ScheduleHeuristic::InnerOuterPersistent - }; + ScheduleHeuristic::InnerOuterPersistent}; std::string toString(ScheduleHeuristic sh); From 5cb9fe572dd98abc57192f0c0225520f4452b3e9 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 7 May 2024 19:29:44 +0000 Subject: [PATCH 17/35] 1D case, review comments --- csrc/ops/utils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 753440f51dc..f7c7bd7a3b6 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -12,6 +12,7 @@ #include #include #include +#include #include From 2047629274a49b2b5bfe6f9d31878122de16a9a5 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 7 May 2024 23:52:33 +0000 Subject: [PATCH 18/35] move common code --- csrc/root_domain_map.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 69116f4c09a..68fe0bea616 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -99,6 +99,22 @@ std::pair getIndexedDomainInfo( return std::make_pair(indexed_id, has_consumer_id); } +// Add key-value iterdomain pair to the map. +void updatePairwiseRootDomainMap ( + IterDomain* map_key_id, + IterDomain* map_value_id, + const std::unordered_set& root_dims_to_map, + bool producer_to_consumer, + std::unordered_map& dom_map +) { + if (!producer_to_consumer) { + std::swap(map_key_id, map_value_id); + } + if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { + dom_map.insert(std::make_pair(map_key_id, map_value_id)); + } +} + } // namespace std::unordered_map PairwiseRootDomainMap::map( From bc860191642edb841d39207eaba157d94435f26b Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 7 May 2024 23:55:41 +0000 Subject: [PATCH 19/35] format --- csrc/ops/utils.h | 1 - csrc/root_domain_map.cpp | 13 ++++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index f7c7bd7a3b6..753440f51dc 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -12,7 +12,6 @@ #include #include #include -#include #include diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 68fe0bea616..fd905e1f20b 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -100,13 +100,12 @@ std::pair getIndexedDomainInfo( } // Add key-value iterdomain pair to the map. -void updatePairwiseRootDomainMap ( - IterDomain* map_key_id, - IterDomain* map_value_id, - const std::unordered_set& root_dims_to_map, - bool producer_to_consumer, - std::unordered_map& dom_map -) { +void updatePairwiseRootDomainMap( + IterDomain* map_key_id, + IterDomain* map_value_id, + const std::unordered_set& root_dims_to_map, + bool producer_to_consumer, + std::unordered_map& dom_map) { if (!producer_to_consumer) { std::swap(map_key_id, map_value_id); } From 6d14c3200d4b31f5c2e111a3d3d75b9eb28e2afa Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 9 May 2024 17:16:09 +0000 Subject: [PATCH 20/35] review comments --- csrc/ops/composite.cpp | 4 ++++ csrc/root_domain_map.cpp | 15 --------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 294d48e829f..6245e079d9b 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -370,6 +370,10 @@ TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b) { " and ", tv_b->dtype()); + // Note: torch.matmul reference does not restrict the inputs to the same dtype, but it fails for different input dtypes. + // This condition may potentially be modified. The following condition should change accordingly. + NVF_CHECK(tv_a->dtype() == tv_b->dtype(), "Expected A and B dtypes to have the same dtype, got: ", tv_a->dtype(), " and ", tv_b->dtype()); + if (tv_a->nDims() == 1 && tv_b->nDims() == 1) { // Return the dot product instead of creating the MatmulOp. // Cast back the output if needed since torch.matmul maintains input dtype. diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index fd905e1f20b..69116f4c09a 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -99,21 +99,6 @@ std::pair getIndexedDomainInfo( return std::make_pair(indexed_id, has_consumer_id); } -// Add key-value iterdomain pair to the map. -void updatePairwiseRootDomainMap( - IterDomain* map_key_id, - IterDomain* map_value_id, - const std::unordered_set& root_dims_to_map, - bool producer_to_consumer, - std::unordered_map& dom_map) { - if (!producer_to_consumer) { - std::swap(map_key_id, map_value_id); - } - if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { - dom_map.insert(std::make_pair(map_key_id, map_value_id)); - } -} - } // namespace std::unordered_map PairwiseRootDomainMap::map( From 9910676b6478ed3f47f4eb5e058dde235ad2ab85 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 9 May 2024 17:30:03 +0000 Subject: [PATCH 21/35] lin --- csrc/ops/composite.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 6245e079d9b..294d48e829f 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -370,10 +370,6 @@ TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b) { " and ", tv_b->dtype()); - // Note: torch.matmul reference does not restrict the inputs to the same dtype, but it fails for different input dtypes. - // This condition may potentially be modified. The following condition should change accordingly. - NVF_CHECK(tv_a->dtype() == tv_b->dtype(), "Expected A and B dtypes to have the same dtype, got: ", tv_a->dtype(), " and ", tv_b->dtype()); - if (tv_a->nDims() == 1 && tv_b->nDims() == 1) { // Return the dot product instead of creating the MatmulOp. // Cast back the output if needed since torch.matmul maintains input dtype. From b9613f064fb6b36ddb5865f9d4a07ce8287eb934 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 2 May 2024 22:36:42 +0000 Subject: [PATCH 22/35] add to scheduler --- CMakeLists.txt | 1 + csrc/scheduler/all_schedulers.h | 1 + csrc/scheduler/expr_eval_sched.cpp | 58 ++++++++++++++++++++++++++ csrc/scheduler/expr_eval_sched.h | 66 ++++++++++++++++++++++++++++++ csrc/scheduler/heuristic_types.cpp | 2 + csrc/scheduler/heuristic_types.h | 6 ++- csrc/scheduler/registry.cpp | 12 ++++++ 7 files changed, 144 insertions(+), 2 deletions(-) create mode 100644 csrc/scheduler/expr_eval_sched.cpp create mode 100644 csrc/scheduler/expr_eval_sched.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8795f167edc..b573b0498af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,6 +205,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp + ${NVFUSER_SRCS_DIR}/scheduler/expr_eval_sched.cpp ${NVFUSER_SRCS_DIR}/serde/polymorphic_value.cpp ${NVFUSER_SRCS_DIR}/serde/utils.cpp ${NVFUSER_SRCS_DIR}/swizzle.cpp diff --git a/csrc/scheduler/all_schedulers.h b/csrc/scheduler/all_schedulers.h index 5bdf013c2de..08a33343d6a 100644 --- a/csrc/scheduler/all_schedulers.h +++ b/csrc/scheduler/all_schedulers.h @@ -6,6 +6,7 @@ */ // clang-format on #pragma once +#include #include #include #include diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp new file mode 100644 index 00000000000..11a7989b6b0 --- /dev/null +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -0,0 +1,58 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include + +namespace nvfuser { + +template +void vlog(const Args&... args) { + scheduler_debug_utils::log("[Expression Evaluator Scheduler] ", args...); +} + +ExprEvalScheduler::ExprEvalScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) + : SchedulerEntry(heuristicType()) { + params_ = std::make_shared("", runtime_info.getIndexType()); +} + +//! Check if the no-op heuristics apply in given fusion +bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { + // Check if the fusion has matmul node and accept + if (fusion->outputs().size() == 1 && fusion->outputs().front()->definition()->isA()) { + return true; + } + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), "Only accepts MatmulOp"); + return false; +} + +bool ExprEvalScheduler::canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + return true; +} + +void ExprEvalScheduler::schedule(Fusion* fusion) { + fusion->aliasOutputToInput( + fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); +} + +void ExprEvalScheduler::computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + return; +} +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h new file mode 100644 index 00000000000..b56610cc686 --- /dev/null +++ b/csrc/scheduler/expr_eval_sched.h @@ -0,0 +1,66 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { + +class Fusion; +class SchedulerRuntimeInfo; +class HeuristicSummary; + +//! ExprEval scheduler represents the case where we allocate outputs directly using EE. No code is generated. +class ExprEvalScheduler : public SchedulerEntry { + public: + explicit ExprEvalScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); + + //! This scheduler only accepts matmul and linear nodes + static bool canScheduleCompileTime(Fusion* fusion); + + static bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); + + constexpr static ScheduleHeuristic heuristicType() { + return ScheduleHeuristic::ExprEval; + } + + void schedule(Fusion* fusion) override; + + private: + void computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); +}; + +//! Provides a dummy heuristic type to ensure +//! unified interface on ExprEval scheduler. +class ExprEvalHeuristic : public HeuristicParams { + public: + using HeuristicParams::HeuristicParams; + + size_t hash() const override { + return 0; + } + std::shared_ptr clone() const override { + return std::make_shared(); + } + bool sameAs(const std::shared_ptr& other) const override { + auto other_casted = std::dynamic_pointer_cast(other); + return other_casted != nullptr && other_casted->cparams == cparams; + }; +}; + +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/heuristic_types.cpp b/csrc/scheduler/heuristic_types.cpp index c85266685c7..3b1c5a7e97e 100644 --- a/csrc/scheduler/heuristic_types.cpp +++ b/csrc/scheduler/heuristic_types.cpp @@ -29,6 +29,8 @@ std::string toString(ScheduleHeuristic sh) { return "transpose"; case ScheduleHeuristic::Matmul: return "matmul"; + case ScheduleHeuristic::ExprEval: + return "expr_eval"; case ScheduleHeuristic::None: return "none"; default: diff --git a/csrc/scheduler/heuristic_types.h b/csrc/scheduler/heuristic_types.h index e30169141b4..c5f48c9e6c0 100644 --- a/csrc/scheduler/heuristic_types.h +++ b/csrc/scheduler/heuristic_types.h @@ -55,12 +55,14 @@ enum class ScheduleHeuristic { InnerPersistent, InnerOuterPersistent, OuterPersistent, - Transpose + Transpose, + ExprEval }; //! Define a schedule table to loop over all the heuristics in priority order. -constexpr std::array all_heuristics_in_priority_order = { +constexpr std::array all_heuristics_in_priority_order = { ScheduleHeuristic::NoOp, + ScheduleHeuristic::ExprEval, ScheduleHeuristic::Matmul, ScheduleHeuristic::Reduction, ScheduleHeuristic::Transpose, diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index a9680ecb2f5..eca8183b4d8 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -206,6 +206,8 @@ bool checkCanSchedule( case ScheduleHeuristic::Matmul: return checkCanSchedule( fusion, runtime_info, data_cache); + case ScheduleHeuristic::ExprEval: + return checkCanSchedule(fusion, runtime_info, data_cache); default: NVF_ERROR(false, "unreachable"); return false; @@ -252,6 +254,10 @@ bool checkCanSchedule( scheduler_entry = std::make_unique(fusion, runtime_info, data_cache); break; + case ScheduleHeuristic::ExprEval: + scheduler_entry = + std::make_unique(fusion, runtime_info, data_cache); + break; default: NVF_ERROR(false, "unreachable"); } @@ -342,6 +348,9 @@ HeuristicSummary::HeuristicSummary( NVF_ERROR(canSchedule, "Could not schedule matmul (run time)"); break; } + case ScheduleHeuristic::ExprEval: + ExprEvalScheduler::canScheduleRunTime(fusion, runtime_info, this); + break; default: NVF_ERROR(false, "unknown heuristic"); } @@ -418,6 +427,9 @@ void HeuristicSummary::validate() const { case ScheduleHeuristic::Matmul: { // TODO: add a proper set of checks break; + } + case ScheduleHeuristic::ExprEval: { + break; } default: NVF_ERROR(false, "unknown heuristic"); From 6a059ab320295a18de27f5dff95be93570703547 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 7 May 2024 21:04:40 +0000 Subject: [PATCH 23/35] check for matmul op --- csrc/scheduler/expr_eval_sched.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 11a7989b6b0..d6be8e623f3 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -29,7 +29,8 @@ ExprEvalScheduler::ExprEvalScheduler( //! Check if the no-op heuristics apply in given fusion bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { // Check if the fusion has matmul node and accept - if (fusion->outputs().size() == 1 && fusion->outputs().front()->definition()->isA()) { + auto exprs = fusion->exprs(); + if (exprs->size() == 1 && exprs.front()->isA()){ return true; } scheduler_debug_utils::canScheduleRejectReason( From 5b9a4a3dfa4c2f88b4888d814c7d7333ded9ec5f Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Thu, 9 May 2024 13:31:31 -0400 Subject: [PATCH 24/35] Update csrc/scheduler/expr_eval_sched.cpp Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> --- csrc/scheduler/expr_eval_sched.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index d6be8e623f3..702008bbf80 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -34,7 +34,7 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return true; } scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), "Only accepts MatmulOp"); + heuristicType(), "Fusion must contain a single expression of type MatmulOp"); return false; } From 1ad32eaee8b97945ab9f336a958df5446753e81b Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 9 May 2024 22:28:15 +0000 Subject: [PATCH 25/35] remove unused functions, add defaults to heuristic param --- csrc/scheduler/expr_eval_sched.cpp | 18 +++--------------- csrc/scheduler/expr_eval_sched.h | 27 ++------------------------- csrc/scheduler/heuristic.h | 17 +++++++++++++---- 3 files changed, 18 insertions(+), 44 deletions(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 702008bbf80..cd448f87275 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -13,24 +13,18 @@ namespace nvfuser { -template -void vlog(const Args&... args) { - scheduler_debug_utils::log("[Expression Evaluator Scheduler] ", args...); -} - ExprEvalScheduler::ExprEvalScheduler( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) : SchedulerEntry(heuristicType()) { - params_ = std::make_shared("", runtime_info.getIndexType()); + params_ = std::make_shared(); } -//! Check if the no-op heuristics apply in given fusion +// Check if the fusion has a single MatmulOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { - // Check if the fusion has matmul node and accept auto exprs = fusion->exprs(); - if (exprs->size() == 1 && exprs.front()->isA()){ + if (exprs.size() == 1 && exprs.front()->isA()){ return true; } scheduler_debug_utils::canScheduleRejectReason( @@ -50,10 +44,4 @@ void ExprEvalScheduler::schedule(Fusion* fusion) { fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); } -void ExprEvalScheduler::computeHeuristics( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache) { - return; -} } // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h index b56610cc686..09731f77c0b 100644 --- a/csrc/scheduler/expr_eval_sched.h +++ b/csrc/scheduler/expr_eval_sched.h @@ -16,7 +16,7 @@ class Fusion; class SchedulerRuntimeInfo; class HeuristicSummary; -//! ExprEval scheduler represents the case where we allocate outputs directly using EE. No code is generated. +// ExprEval scheduler represents the case where we allocate outputs directly using EE. No code is generated. class ExprEvalScheduler : public SchedulerEntry { public: explicit ExprEvalScheduler( @@ -24,7 +24,7 @@ class ExprEvalScheduler : public SchedulerEntry { SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache = nullptr); - //! This scheduler only accepts matmul and linear nodes + // This scheduler only accepts MatmulOp. static bool canScheduleCompileTime(Fusion* fusion); static bool canScheduleRunTime( @@ -37,30 +37,7 @@ class ExprEvalScheduler : public SchedulerEntry { } void schedule(Fusion* fusion) override; - - private: - void computeHeuristics( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr); }; -//! Provides a dummy heuristic type to ensure -//! unified interface on ExprEval scheduler. -class ExprEvalHeuristic : public HeuristicParams { - public: - using HeuristicParams::HeuristicParams; - - size_t hash() const override { - return 0; - } - std::shared_ptr clone() const override { - return std::make_shared(); - } - bool sameAs(const std::shared_ptr& other) const override { - auto other_casted = std::dynamic_pointer_cast(other); - return other_casted != nullptr && other_casted->cparams == cparams; - }; -}; } // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/heuristic.h b/csrc/scheduler/heuristic.h index 92ba567fc41..db761bbbc0c 100644 --- a/csrc/scheduler/heuristic.h +++ b/csrc/scheduler/heuristic.h @@ -25,11 +25,20 @@ class HeuristicParams : public PolymorphicBase { return "Undefined Heuristic Params"; } - virtual size_t hash() const = 0; - - virtual bool sameAs(const std::shared_ptr& other) const = 0; + virtual size_t hash() const { + return 0; + }; + + virtual bool sameAs(const std::shared_ptr& other) const { + if (!other->isA()) { + return false; + } + return other->cparams == cparams; + } - virtual std::shared_ptr clone() const = 0; + virtual std::shared_ptr clone() const { + return std::make_shared(); + } HeuristicParams() = default; HeuristicParams(std::string tag, PrimDataType index_type) From a1b22e20f60fb515ff403c721b8acd337c04bc9c Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 10 May 2024 02:03:53 +0000 Subject: [PATCH 26/35] change scheduler order --- csrc/scheduler/expr_eval_sched.cpp | 2 +- csrc/scheduler/heuristic_types.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index cd448f87275..8ba08d3f164 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -18,7 +18,7 @@ ExprEvalScheduler::ExprEvalScheduler( SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) : SchedulerEntry(heuristicType()) { - params_ = std::make_shared(); + params_ = std::make_shared("", runtime_info.getIndexType()); } // Check if the fusion has a single MatmulOp node diff --git a/csrc/scheduler/heuristic_types.h b/csrc/scheduler/heuristic_types.h index c5f48c9e6c0..bbc3949c8d6 100644 --- a/csrc/scheduler/heuristic_types.h +++ b/csrc/scheduler/heuristic_types.h @@ -61,8 +61,8 @@ enum class ScheduleHeuristic { //! Define a schedule table to loop over all the heuristics in priority order. constexpr std::array all_heuristics_in_priority_order = { - ScheduleHeuristic::NoOp, ScheduleHeuristic::ExprEval, + ScheduleHeuristic::NoOp, ScheduleHeuristic::Matmul, ScheduleHeuristic::Reduction, ScheduleHeuristic::Transpose, From bcd5791b1b960f7893c2add3f13180e4a1299a27 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 13 May 2024 20:25:25 +0000 Subject: [PATCH 27/35] fix comparison --- csrc/root_domain_map.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 69116f4c09a..9b1a1664ad6 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -141,7 +141,7 @@ std::unordered_map PairwiseRootDomainMap::map( if (MatmulOp* op = dynamic_cast(consumer_tv_->definition())) { // Check if the producer is lhs/rhs input MatmulRole input_role = - producer->sameAs(op->inA()) ? MatmulRole::INPUT_A : MatmulRole::INPUT_B; + producer->sameAs(op->inA()->as()->domain()) ? MatmulRole::INPUT_A : MatmulRole::INPUT_B; auto out_size = consumer_root.size(); // For MatmulOp, the input iterdomains at a given index do not necessarily @@ -158,7 +158,9 @@ std::unordered_map PairwiseRootDomainMap::map( for (auto inx : c10::irange(out_size)) { IterDomain* map_key_id = aligned_producer_id.at(inx); IterDomain* map_value_id = consumer_root.at(inx); - updatePairwiseRootDomainMap(map_key_id, map_value_id); + if (map_key_id != nullptr){ + updatePairwiseRootDomainMap(map_key_id, map_value_id); + } } return dom_map; } From 71834085468d7934ee11a26c98ab9bf89e8f31a0 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 13 May 2024 21:05:34 +0000 Subject: [PATCH 28/35] chech broadcast and symbolic conditions --- csrc/root_domain_map.cpp | 22 +++++++++++++++++----- tests/cpp/test_matmul_aten_evaluation.cpp | 20 ++++---------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 9b1a1664ad6..16af6579fc9 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -152,16 +152,28 @@ std::unordered_map PairwiseRootDomainMap::map( // input and output for index=2 // 2. `B, M, K] x [K, N] -> [B, M, N]`: For input B, the second iterdomain // maps to the third output iterdomain. - const std::vector& aligned_producer_id = + const std::vector& aligned_producer_ids = ops::mapMatmulOpIterDomains(producer_root, input_role, out_size); for (auto inx : c10::irange(out_size)) { - IterDomain* map_key_id = aligned_producer_id.at(inx); - IterDomain* map_value_id = consumer_root.at(inx); - if (map_key_id != nullptr){ - updatePairwiseRootDomainMap(map_key_id, map_value_id); + IterDomain* producer_id = aligned_producer_ids.at(inx); + IterDomain* consumer_id = consumer_root.at(inx); + + if (producer_id == nullptr){ + continue; + } + if (!map_broadcast_ && + producer_id->isBroadcast() != consumer_id->isBroadcast()){ + continue; + } + if (!map_symbolic_ && + (producer_id->isSymbolic() || consumer_id->isSymbolic()) && + (!producer_id->extent()->sameAs(consumer_id->extent()))) { + continue; } + updatePairwiseRootDomainMap(producer_id, consumer_id); } + return dom_map; } diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index 11fe6c33f9b..9e73c33c036 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -427,14 +427,8 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); at::Tensor out_ref = at::matmul(t0, t1); - FusionExecutor fe; - fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); - fe.compileFusion(fusion.get(), {t0, t1}); - auto out = fe.runFusion({t0, t1}); - - // Verify that fusion compilation was skipped. - EXPECT_FALSE(fe.hasCompiledKernel()); + FusionExecutorCache fec(std::move(fusion)); + auto out = fec.runFusionWithInputs({t0, t1}); EXPECT_TRUE(at::allclose(out[0], out_ref)); } @@ -457,14 +451,8 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeSymbolic) { at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); at::Tensor out_ref = at::matmul(t0, t1); - FusionExecutor fe; - fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); - fe.compileFusion(fusion.get(), {t0, t1}); - auto out = fe.runFusion({t0, t1}); - - // Verify that fusion compilation was skipped. - EXPECT_FALSE(fe.hasCompiledKernel()); + FusionExecutorCache fec(std::move(fusion)); + auto out = fec.runFusionWithInputs({t0, t1}); EXPECT_TRUE(at::allclose(out[0], out_ref)); } From ffd48dfc09ecfb791cedf631020a0259de5a329a Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 13 May 2024 21:12:40 +0000 Subject: [PATCH 29/35] rename API --- csrc/ops/composite.cpp | 38 +---------------------- csrc/ops/composite.h | 12 +------ tests/cpp/test_matmul_aten_evaluation.cpp | 4 +-- 3 files changed, 4 insertions(+), 50 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 294d48e829f..d44153e8800 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -54,42 +54,6 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) { return dx; } -// This function will add a castOp to the output of the matrix multiplication -// The implementation of linear can use this but will skip the cast (set cast -// flag as false) and add the bias. -TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input) { - NVF_CHECK( - a->nDims() == b->nDims(), - "The number of dimension of A and B do not match"); - // TODO: We'll need to suppor nDims == 3 for bmm. - NVF_CHECK( - a->nDims() == 2, - "Only 2-D Tensors are supported, in the future we'll support 3-D as well!"); - - std::vector bcast_dims(a->nDims() + 1, false); - // A: [M, K, Bcast] - // B: [Bcast, K, N] - bcast_dims.at(bcast_dims.size() - 1) = true; - auto* tv0b = broadcast(a, bcast_dims); - bcast_dims.at(bcast_dims.size() - 1) = false; - bcast_dims.at(bcast_dims.size() - 3) = true; - auto* tv1b = broadcast(b, bcast_dims); - - NVF_CHECK( - a->getDataType().value() == b->getDataType().value(), - "data types of inputs to matmul don't match"); - auto* output = fusedMultiplySum(tv0b, tv1b, {-2}); - if (cast_output_to_input) { - // For matmul, the output dtype should match input. - return maybeCastOp(a->getDataType().value(), output); - } - return output; -} - -TensorView* matmul(TensorView* a, TensorView* b) { - return matmul(a, b, true /* cast output to input dtype */); -} - TensorView* linear(TensorView* a, TensorView* b, TensorView* bias) { // TODO: Support 1+ dimensional A. NVF_CHECK( @@ -351,7 +315,7 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { // TODO (Priya): This will be renamed to matmul once we are ready to modify the // python API backend. Keeping separate for now, to avoid breaking tests in // Thunder. -TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b) { +TensorView* matmul(TensorView* tv_a, TensorView* tv_b) { NVF_CHECK( tv_a->nDims() > 0 && tv_b->nDims() > 0, "Expected inputs to be atleast 1D, got: ", diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index 365b0144304..053553848d7 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -47,16 +47,6 @@ NVF_API LstmResult lstm( TensorView* cell_x, TensorView* out_x); -// Matmul function which takes in tensors with the shapes -// A[M,K] B[K,N], but the tensors may have different layouts -// via strides. All restrictions from the matmul APIs also -// apply here. -TensorView* matmul(TensorView* a, TensorView* b); -// This second matmul function is not exposed via -// the Python interface, but it does the guts of the work and -// can be used to create mamtuls without a cast operation following it. -TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input); - // Linear functions which takes in two tensors of shapes A[M,K] and // B[N,K]. Takes in a options bias of shape [N] and performs // out = A * B_Transpose + bias. The output dtype matches the dtype @@ -81,6 +71,6 @@ TensorView* leaky_relu(TensorView* x, Val* negative_slope); NVF_API TensorView* view_as_real(TensorView* x); -TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b); +TensorView* matmul(TensorView* tv_a, TensorView* tv_b); } // namespace nvfuser diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index 9e73c33c036..e9c0b190f2c 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -417,7 +417,7 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { auto tv0 = makeConcreteTensor(a_shape, DataType::Half); auto tv1 = makeConcreteTensor(b_shape, DataType::Half); - auto tv2 = eagerMatmul(tv0, tv1); + auto tv2 = matmul(tv0, tv1); fusion->addInput(tv0); fusion->addInput(tv1); @@ -441,7 +441,7 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeSymbolic) { auto tv0 = makeSymbolicTensor(a_shape, DataType::Half); auto tv1 = makeSymbolicTensor(b_shape, DataType::Half); - auto tv2 = eagerMatmul(tv0, tv1); + auto tv2 = matmul(tv0, tv1); fusion->addInput(tv0); fusion->addInput(tv1); From 8f80548c319fb3c3db60b23115fd77d2dc0c25d2 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 13 May 2024 21:34:28 +0000 Subject: [PATCH 30/35] modify matmul generator to use cases from Thunder --- tests/python/pytest_input_generators.py | 37 +++++++++++++++++++++---- tests/python/pytest_opinfos.py | 7 +++-- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/tests/python/pytest_input_generators.py b/tests/python/pytest_input_generators.py index 53220008b90..a8c1702a3e3 100644 --- a/tests/python/pytest_input_generators.py +++ b/tests/python/pytest_input_generators.py @@ -1487,8 +1487,37 @@ def vector_at_error_generator( make_arg(error_case["tensor_shape"]), index=error_case["index"] ), error_type, error_msg +def matmul_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs): + make_arg = partial( + make_tensor, + dtype=dtype, + device="cuda", + low=None, + high=None, + requires_grad=requires_grad, + ) + + B = 64 + M = 512 + N = 256 + K = 32 -def matmul_or_linear_input_generator( + # shape_a, shape_b + cases = ( + ((K,), (K,)), + ((K,), (K, N)), + ((M, K), (K,)), + ((K,), (B, K, N)), + ((B, M, K), (K,)), + ((M, K), (K, N)), + ((B, M, K), (B, K, N)), + ((B, B, M, K), (B, B, K, N)), + ) + + for shape_a, shape_b in cases: + yield SampleInput(make_arg(shape_a), make_arg(shape_b)) + +def linear_input_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): make_arg = partial( @@ -1507,17 +1536,13 @@ def multiply_range(maximum, step): map(pow, itertools.repeat(step, num_steps), range(1, num_steps + 1)) ) - is_linear = op.name == "linear" - # Ranges of tensor sizes: 8, 64, 512, 4096, 32768, ... # Use a Cartesian product to create a wide range of matrix shapes # I'll stop at 512 as possible numerical difference may show up. M, N, K = itertools.repeat(multiply_range(512, 8), 3) for M, N, K in itertools.product(M, N, K): lhs_shape = (M, K) - rhs_shape = (N, K) if is_linear else (K, N) + rhs_shape = (N, K) yield ( SampleInput(make_arg(lhs_shape), make_arg(rhs_shape), make_arg((N,))) - if is_linear - else SampleInput(make_arg(lhs_shape), make_arg(rhs_shape)) ) diff --git a/tests/python/pytest_opinfos.py b/tests/python/pytest_opinfos.py index 740676f934b..348da16d171 100644 --- a/tests/python/pytest_opinfos.py +++ b/tests/python/pytest_opinfos.py @@ -48,7 +48,8 @@ var_mean_generator, vector_at_error_generator, where_error_generator, - matmul_or_linear_input_generator, + matmul_input_generator, + linear_input_generator ) from pytest_utils import ( bool_int_dtypes, @@ -1115,7 +1116,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 else (torch.float16,) ), - sample_input_generator=matmul_or_linear_input_generator, + sample_input_generator=matmul_input_generator, reference=torch.matmul, ) matmul_ops.append(matmul_opinfo) @@ -1131,7 +1132,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 else (torch.float16,) ), - sample_input_generator=matmul_or_linear_input_generator, + sample_input_generator=linear_input_generator, reference=torch.nn.functional.linear, ) linear_ops.append(linear_opinfo) From a11c4a9b6da5048627d7b4f9d4d3552a42fc5ea5 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 13 May 2024 23:48:40 +0000 Subject: [PATCH 31/35] refactor code --- csrc/root_domain_map.cpp | 80 +++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 47 deletions(-) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 16af6579fc9..47e039014f0 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -125,13 +125,42 @@ std::unordered_map PairwiseRootDomainMap::map( TensorDomain::noReductions(producer->maybeRFactor()); const auto& consumer_root = consumer->root(); - // Add key-value iterdomain pair to the map. + // Check following conditions and add key-value iterdomain pair to domain map: + // 1. Do not map broadcast ID to non-broadcast ID unless map_broadcast_ = true. + // 2. Do not map Symbolic ID if the extents are not identical unless map_symbolic_ = true. auto updatePairwiseRootDomainMap = - [&root_dims_to_map, producer_to_consumer, &dom_map]( - IterDomain* map_key_id, IterDomain* map_value_id) { + [&]( + IterDomain* producer_id, IterDomain* consumer_id) { + if (!map_broadcast_ && + producer_id->isBroadcast() != consumer_id->isBroadcast()){ + return; + } + + // Condition: At least one ID is symbolic. + // + // If map_symbolic_ is true: + // Map these IDs regardless of other considerations. + // + // If map_symbolic_ is false (default): + // Map these only if their extents are identical. IterType::Symbolic + // reflects that the extent might evaluate to 1 for some inputs, in which + // case it may be valid to use those domains in a broadcast op. If the + // extents are exactly the same between two aligned IterDomains, the + // Symbolic one will be concretized to the same IterType as the other, so + // they should be mapped with one another. + if (!map_symbolic_ && + (producer_id->isSymbolic() || consumer_id->isSymbolic()) && + (!producer_id->extent()->sameAs(consumer_id->extent()))) { + return; + } + + IterDomain* map_key_id = producer_id; + IterDomain* map_value_id = consumer_id; + if (!producer_to_consumer) { std::swap(map_key_id, map_value_id); } + if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { dom_map.insert(std::make_pair(map_key_id, map_value_id)); } @@ -158,19 +187,9 @@ std::unordered_map PairwiseRootDomainMap::map( for (auto inx : c10::irange(out_size)) { IterDomain* producer_id = aligned_producer_ids.at(inx); IterDomain* consumer_id = consumer_root.at(inx); - if (producer_id == nullptr){ continue; } - if (!map_broadcast_ && - producer_id->isBroadcast() != consumer_id->isBroadcast()){ - continue; - } - if (!map_symbolic_ && - (producer_id->isSymbolic() || consumer_id->isSymbolic()) && - (!producer_id->extent()->sameAs(consumer_id->extent()))) { - continue; - } updatePairwiseRootDomainMap(producer_id, consumer_id); } @@ -187,8 +206,6 @@ std::unordered_map PairwiseRootDomainMap::map( // 2. IDs that may have different extents (e.g., non indexed // domains of torch_gather) // 3. Squeeze and unsqueeze - // 4. Broadcast and non broadcast - // 5. Symbolic ID with different extent from other ID // Condition 1: when the producer ID is the dim of a select-like op if (producer_id == indexed_producer_id) { @@ -233,38 +250,7 @@ std::unordered_map PairwiseRootDomainMap::map( continue; } - // Condition 4 - if (!map_broadcast_ && - producer_id->isBroadcast() != consumer_id->isBroadcast()) { - itc++; - itp++; - continue; - } - - // Condition 5 - // At least one ID is symbolic. - // - // If map_symbolic_ is true: - // Map these IDs regardless of other considerations. - // - // If map_symbolic_ is false (default): - // Map these only if their extents are identical. IterType::Symbolic - // reflects that the extent might evaluate to 1 for some inputs, in which - // case it may be valid to use those domains in a broadcast op. If the - // extents are exactly the same between two aligned IterDomains, the - // Symbolic one will be concretized to the same IterType as the other, so - // they should be mapped with one another. - if (!map_symbolic_ && - (producer_id->isSymbolic() || consumer_id->isSymbolic()) && - (!producer_id->extent()->sameAs(consumer_id->extent()))) { - itc++; - itp++; - continue; - } - - IterDomain* map_key_id = producer_id; - IterDomain* map_value_id = consumer_id; - updatePairwiseRootDomainMap(map_key_id, map_value_id); + updatePairwiseRootDomainMap(producer_id, consumer_id); itc++; itp++; From d05131132738f166166a03ee69389655f0f5a019 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 13 May 2024 23:57:31 +0000 Subject: [PATCH 32/35] bump version --- csrc/root_domain_map.cpp | 1 - version.txt | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 47e039014f0..ea44370b4b3 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include diff --git a/version.txt b/version.txt index 7179039691c..abd410582de 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.3 +0.2.4 From 41e2bbb8762701cde1493354a6e5a9e81cb8cbd4 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 14 May 2024 00:54:30 +0000 Subject: [PATCH 33/35] review comments --- csrc/ops/composite.h | 3 +++ csrc/scheduler/expr_eval_sched.cpp | 6 ------ csrc/scheduler/expr_eval_sched.h | 11 ++++++++--- csrc/scheduler/heuristic.h | 2 +- tests/python/pytest_input_generators.py | 2 +- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index 053553848d7..bcc4c8fdb5a 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -71,6 +71,9 @@ TensorView* leaky_relu(TensorView* x, Val* negative_slope); NVF_API TensorView* view_as_real(TensorView* x); +// Matmul function which takes in tensors with the shapes +// A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different layouts +// via strides. This has the same functionality as torch.matmul TensorView* matmul(TensorView* tv_a, TensorView* tv_b); } // namespace nvfuser diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 8ba08d3f164..90c11aca9d2 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -32,12 +32,6 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } -bool ExprEvalScheduler::canScheduleRunTime( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache) { - return true; -} void ExprEvalScheduler::schedule(Fusion* fusion) { fusion->aliasOutputToInput( diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h index 09731f77c0b..ed2473309ba 100644 --- a/csrc/scheduler/expr_eval_sched.h +++ b/csrc/scheduler/expr_eval_sched.h @@ -22,15 +22,20 @@ class ExprEvalScheduler : public SchedulerEntry { explicit ExprEvalScheduler( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr); + HeuristicSummary* data_cache = nullptr): SchedulerEntry(heuristicType()) { + params_ = std::make_shared("", runtime_info.getIndexType()); +} // This scheduler only accepts MatmulOp. static bool canScheduleCompileTime(Fusion* fusion); - static bool canScheduleRunTime( + + static bool ExprEvalScheduler::canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr); + HeuristicSummary* data_cache) { + return true; + } constexpr static ScheduleHeuristic heuristicType() { return ScheduleHeuristic::ExprEval; diff --git a/csrc/scheduler/heuristic.h b/csrc/scheduler/heuristic.h index db761bbbc0c..a0d73543cbc 100644 --- a/csrc/scheduler/heuristic.h +++ b/csrc/scheduler/heuristic.h @@ -30,7 +30,7 @@ class HeuristicParams : public PolymorphicBase { }; virtual bool sameAs(const std::shared_ptr& other) const { - if (!other->isA()) { + if (!other->isStrictlyA()) { return false; } return other->cparams == cparams; diff --git a/tests/python/pytest_input_generators.py b/tests/python/pytest_input_generators.py index a8c1702a3e3..bc1c070fe52 100644 --- a/tests/python/pytest_input_generators.py +++ b/tests/python/pytest_input_generators.py @@ -1511,7 +1511,7 @@ def matmul_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = ((B, M, K), (K,)), ((M, K), (K, N)), ((B, M, K), (B, K, N)), - ((B, B, M, K), (B, B, K, N)), + ((B, 1, M, K), (B, K, N)), ) for shape_a, shape_b in cases: From 6e3aa0ac512803c9afd70615ce1a371342037a20 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 14 May 2024 01:15:10 +0000 Subject: [PATCH 34/35] review comments --- csrc/ops/composite.cpp | 3 --- csrc/scheduler/expr_eval_sched.cpp | 9 --------- csrc/scheduler/expr_eval_sched.h | 2 +- 3 files changed, 1 insertion(+), 13 deletions(-) diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index d44153e8800..a151dbbb566 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -312,9 +312,6 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { } // namespace -// TODO (Priya): This will be renamed to matmul once we are ready to modify the -// python API backend. Keeping separate for now, to avoid breaking tests in -// Thunder. TensorView* matmul(TensorView* tv_a, TensorView* tv_b) { NVF_CHECK( tv_a->nDims() > 0 && tv_b->nDims() > 0, diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 90c11aca9d2..231cbac6489 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -13,14 +13,6 @@ namespace nvfuser { -ExprEvalScheduler::ExprEvalScheduler( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache) - : SchedulerEntry(heuristicType()) { - params_ = std::make_shared("", runtime_info.getIndexType()); -} - // Check if the fusion has a single MatmulOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); @@ -32,7 +24,6 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } - void ExprEvalScheduler::schedule(Fusion* fusion) { fusion->aliasOutputToInput( fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h index ed2473309ba..0e885528e7d 100644 --- a/csrc/scheduler/expr_eval_sched.h +++ b/csrc/scheduler/expr_eval_sched.h @@ -30,7 +30,7 @@ class ExprEvalScheduler : public SchedulerEntry { static bool canScheduleCompileTime(Fusion* fusion); - static bool ExprEvalScheduler::canScheduleRunTime( + static bool canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { From 5786f1391e247940c8b50812ea4756cbc8be9ea8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 14 May 2024 01:32:11 +0000 Subject: [PATCH 35/35] format, clangtidy --- csrc/ops/composite.h | 4 +- csrc/root_domain_map.cpp | 82 +++++++++++++------------ csrc/scheduler/expr_eval_sched.cpp | 9 +-- csrc/scheduler/expr_eval_sched.h | 15 ++--- csrc/scheduler/registry.cpp | 9 ++- tests/python/pytest_input_generators.py | 27 +++----- tests/python/pytest_opinfos.py | 2 +- 7 files changed, 72 insertions(+), 76 deletions(-) diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index bcc4c8fdb5a..fa617e75154 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -72,8 +72,8 @@ TensorView* leaky_relu(TensorView* x, Val* negative_slope); NVF_API TensorView* view_as_real(TensorView* x); // Matmul function which takes in tensors with the shapes -// A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different layouts -// via strides. This has the same functionality as torch.matmul +// A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different +// layouts via strides. This has the same functionality as torch.matmul TensorView* matmul(TensorView* tv_a, TensorView* tv_b); } // namespace nvfuser diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index ea44370b4b3..4e76e255a55 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include @@ -125,51 +124,54 @@ std::unordered_map PairwiseRootDomainMap::map( const auto& consumer_root = consumer->root(); // Check following conditions and add key-value iterdomain pair to domain map: - // 1. Do not map broadcast ID to non-broadcast ID unless map_broadcast_ = true. - // 2. Do not map Symbolic ID if the extents are not identical unless map_symbolic_ = true. - auto updatePairwiseRootDomainMap = - [&]( - IterDomain* producer_id, IterDomain* consumer_id) { - if (!map_broadcast_ && - producer_id->isBroadcast() != consumer_id->isBroadcast()){ - return; - } + // 1. Do not map broadcast ID to non-broadcast ID unless map_broadcast_ = + // true. + // 2. Do not map Symbolic ID if the extents are not identical unless + // map_symbolic_ = true. + auto updatePairwiseRootDomainMap = [&](IterDomain* producer_id, + IterDomain* consumer_id) { + if (!map_broadcast_ && + producer_id->isBroadcast() != consumer_id->isBroadcast()) { + return; + } - // Condition: At least one ID is symbolic. - // - // If map_symbolic_ is true: - // Map these IDs regardless of other considerations. - // - // If map_symbolic_ is false (default): - // Map these only if their extents are identical. IterType::Symbolic - // reflects that the extent might evaluate to 1 for some inputs, in which - // case it may be valid to use those domains in a broadcast op. If the - // extents are exactly the same between two aligned IterDomains, the - // Symbolic one will be concretized to the same IterType as the other, so - // they should be mapped with one another. - if (!map_symbolic_ && - (producer_id->isSymbolic() || consumer_id->isSymbolic()) && - (!producer_id->extent()->sameAs(consumer_id->extent()))) { - return; - } - - IterDomain* map_key_id = producer_id; - IterDomain* map_value_id = consumer_id; + // Condition: At least one ID is symbolic. + // + // If map_symbolic_ is true: + // Map these IDs regardless of other considerations. + // + // If map_symbolic_ is false (default): + // Map these only if their extents are identical. IterType::Symbolic + // reflects that the extent might evaluate to 1 for some inputs, in which + // case it may be valid to use those domains in a broadcast op. If the + // extents are exactly the same between two aligned IterDomains, the + // Symbolic one will be concretized to the same IterType as the other, so + // they should be mapped with one another. + if (!map_symbolic_ && + (producer_id->isSymbolic() || consumer_id->isSymbolic()) && + (!producer_id->extent()->sameAs(consumer_id->extent()))) { + return; + } - if (!producer_to_consumer) { - std::swap(map_key_id, map_value_id); - } + IterDomain* map_key_id = producer_id; + IterDomain* map_value_id = consumer_id; - if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { - dom_map.insert(std::make_pair(map_key_id, map_value_id)); - } - }; + if (!producer_to_consumer) { + std::swap(map_key_id, map_value_id); + } + + if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { + dom_map.insert(std::make_pair(map_key_id, map_value_id)); + } + }; // For MatmulOp, use the corresponding mapped input iterdomains. if (MatmulOp* op = dynamic_cast(consumer_tv_->definition())) { // Check if the producer is lhs/rhs input MatmulRole input_role = - producer->sameAs(op->inA()->as()->domain()) ? MatmulRole::INPUT_A : MatmulRole::INPUT_B; + producer->sameAs(op->inA()->as()->domain()) + ? MatmulRole::INPUT_A + : MatmulRole::INPUT_B; auto out_size = consumer_root.size(); // For MatmulOp, the input iterdomains at a given index do not necessarily @@ -186,12 +188,12 @@ std::unordered_map PairwiseRootDomainMap::map( for (auto inx : c10::irange(out_size)) { IterDomain* producer_id = aligned_producer_ids.at(inx); IterDomain* consumer_id = consumer_root.at(inx); - if (producer_id == nullptr){ + if (producer_id == nullptr) { continue; } updatePairwiseRootDomainMap(producer_id, consumer_id); } - + return dom_map; } diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 231cbac6489..c600b2f0bea 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -16,17 +16,18 @@ namespace nvfuser { // Check if the fusion has a single MatmulOp node bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { auto exprs = fusion->exprs(); - if (exprs.size() == 1 && exprs.front()->isA()){ + if (exprs.size() == 1 && exprs.front()->isA()) { return true; } scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), "Fusion must contain a single expression of type MatmulOp"); + heuristicType(), + "Fusion must contain a single expression of type MatmulOp"); return false; } void ExprEvalScheduler::schedule(Fusion* fusion) { fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); + fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); } -} // namespace nvfuser \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h index 0e885528e7d..a3c99626501 100644 --- a/csrc/scheduler/expr_eval_sched.h +++ b/csrc/scheduler/expr_eval_sched.h @@ -16,20 +16,22 @@ class Fusion; class SchedulerRuntimeInfo; class HeuristicSummary; -// ExprEval scheduler represents the case where we allocate outputs directly using EE. No code is generated. +// ExprEval scheduler represents the case where we allocate outputs directly +// using EE. No code is generated. class ExprEvalScheduler : public SchedulerEntry { public: explicit ExprEvalScheduler( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr): SchedulerEntry(heuristicType()) { - params_ = std::make_shared("", runtime_info.getIndexType()); -} + HeuristicSummary* data_cache = nullptr) + : SchedulerEntry(heuristicType()) { + params_ = + std::make_shared("", runtime_info.getIndexType()); + } // This scheduler only accepts MatmulOp. static bool canScheduleCompileTime(Fusion* fusion); - static bool canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, @@ -44,5 +46,4 @@ class ExprEvalScheduler : public SchedulerEntry { void schedule(Fusion* fusion) override; }; - -} // namespace nvfuser \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index eca8183b4d8..55448a29167 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -207,7 +207,8 @@ bool checkCanSchedule( return checkCanSchedule( fusion, runtime_info, data_cache); case ScheduleHeuristic::ExprEval: - return checkCanSchedule(fusion, runtime_info, data_cache); + return checkCanSchedule( + fusion, runtime_info, data_cache); default: NVF_ERROR(false, "unreachable"); return false; @@ -424,11 +425,9 @@ void HeuristicSummary::validate() const { entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO)); break; } + case ScheduleHeuristic::ExprEval: case ScheduleHeuristic::Matmul: { - // TODO: add a proper set of checks - break; - } - case ScheduleHeuristic::ExprEval: { + // TODO: add a proper set of checks for matmul break; } default: diff --git a/tests/python/pytest_input_generators.py b/tests/python/pytest_input_generators.py index bc1c070fe52..8237381d6ed 100644 --- a/tests/python/pytest_input_generators.py +++ b/tests/python/pytest_input_generators.py @@ -1487,7 +1487,10 @@ def vector_at_error_generator( make_arg(error_case["tensor_shape"]), index=error_case["index"] ), error_type, error_msg -def matmul_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs): + +def matmul_input_generator( + op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs +): make_arg = partial( make_tensor, dtype=dtype, @@ -1497,26 +1500,18 @@ def matmul_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = requires_grad=requires_grad, ) - B = 64 + B = 64 M = 512 N = 256 K = 32 - # shape_a, shape_b - cases = ( - ((K,), (K,)), - ((K,), (K, N)), - ((M, K), (K,)), - ((K,), (B, K, N)), - ((B, M, K), (K,)), - ((M, K), (K, N)), - ((B, M, K), (B, K, N)), - ((B, 1, M, K), (B, K, N)), - ) + shapes_a = ((K,), (M, K), (1, K), (B, M, K), (B, 1, M, K)) + shapes_b = ((K,), (K, N), (K, 1), (B, K, N)) - for shape_a, shape_b in cases: + for shape_a, shape_b in itertools.product(shapes_a, shapes_b): yield SampleInput(make_arg(shape_a), make_arg(shape_b)) + def linear_input_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): @@ -1543,6 +1538,4 @@ def multiply_range(maximum, step): for M, N, K in itertools.product(M, N, K): lhs_shape = (M, K) rhs_shape = (N, K) - yield ( - SampleInput(make_arg(lhs_shape), make_arg(rhs_shape), make_arg((N,))) - ) + yield (SampleInput(make_arg(lhs_shape), make_arg(rhs_shape), make_arg((N,)))) diff --git a/tests/python/pytest_opinfos.py b/tests/python/pytest_opinfos.py index 348da16d171..5d69f57891a 100644 --- a/tests/python/pytest_opinfos.py +++ b/tests/python/pytest_opinfos.py @@ -49,7 +49,7 @@ vector_at_error_generator, where_error_generator, matmul_input_generator, - linear_input_generator + linear_input_generator, ) from pytest_utils import ( bool_int_dtypes,