diff --git a/CMakeLists.txt b/CMakeLists.txt index 8795f167edc..b573b0498af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,6 +205,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp ${NVFUSER_SRCS_DIR}/scheduler/utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp + ${NVFUSER_SRCS_DIR}/scheduler/expr_eval_sched.cpp ${NVFUSER_SRCS_DIR}/serde/polymorphic_value.cpp ${NVFUSER_SRCS_DIR}/serde/utils.cpp ${NVFUSER_SRCS_DIR}/swizzle.cpp diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 294d48e829f..a151dbbb566 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -54,42 +54,6 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) { return dx; } -// This function will add a castOp to the output of the matrix multiplication -// The implementation of linear can use this but will skip the cast (set cast -// flag as false) and add the bias. -TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input) { - NVF_CHECK( - a->nDims() == b->nDims(), - "The number of dimension of A and B do not match"); - // TODO: We'll need to suppor nDims == 3 for bmm. - NVF_CHECK( - a->nDims() == 2, - "Only 2-D Tensors are supported, in the future we'll support 3-D as well!"); - - std::vector bcast_dims(a->nDims() + 1, false); - // A: [M, K, Bcast] - // B: [Bcast, K, N] - bcast_dims.at(bcast_dims.size() - 1) = true; - auto* tv0b = broadcast(a, bcast_dims); - bcast_dims.at(bcast_dims.size() - 1) = false; - bcast_dims.at(bcast_dims.size() - 3) = true; - auto* tv1b = broadcast(b, bcast_dims); - - NVF_CHECK( - a->getDataType().value() == b->getDataType().value(), - "data types of inputs to matmul don't match"); - auto* output = fusedMultiplySum(tv0b, tv1b, {-2}); - if (cast_output_to_input) { - // For matmul, the output dtype should match input. - return maybeCastOp(a->getDataType().value(), output); - } - return output; -} - -TensorView* matmul(TensorView* a, TensorView* b) { - return matmul(a, b, true /* cast output to input dtype */); -} - TensorView* linear(TensorView* a, TensorView* b, TensorView* bias) { // TODO: Support 1+ dimensional A. NVF_CHECK( @@ -348,10 +312,7 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { } // namespace -// TODO (Priya): This will be renamed to matmul once we are ready to modify the -// python API backend. Keeping separate for now, to avoid breaking tests in -// Thunder. -TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b) { +TensorView* matmul(TensorView* tv_a, TensorView* tv_b) { NVF_CHECK( tv_a->nDims() > 0 && tv_b->nDims() > 0, "Expected inputs to be atleast 1D, got: ", diff --git a/csrc/ops/composite.h b/csrc/ops/composite.h index 365b0144304..fa617e75154 100644 --- a/csrc/ops/composite.h +++ b/csrc/ops/composite.h @@ -47,16 +47,6 @@ NVF_API LstmResult lstm( TensorView* cell_x, TensorView* out_x); -// Matmul function which takes in tensors with the shapes -// A[M,K] B[K,N], but the tensors may have different layouts -// via strides. All restrictions from the matmul APIs also -// apply here. -TensorView* matmul(TensorView* a, TensorView* b); -// This second matmul function is not exposed via -// the Python interface, but it does the guts of the work and -// can be used to create mamtuls without a cast operation following it. -TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input); - // Linear functions which takes in two tensors of shapes A[M,K] and // B[N,K]. Takes in a options bias of shape [N] and performs // out = A * B_Transpose + bias. The output dtype matches the dtype @@ -81,6 +71,9 @@ TensorView* leaky_relu(TensorView* x, Val* negative_slope); NVF_API TensorView* view_as_real(TensorView* x); -TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b); +// Matmul function which takes in tensors with the shapes +// A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different +// layouts via strides. This has the same functionality as torch.matmul +TensorView* matmul(TensorView* tv_a, TensorView* tv_b); } // namespace nvfuser diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index b82213fea51..4e76e255a55 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -123,23 +123,55 @@ std::unordered_map PairwiseRootDomainMap::map( TensorDomain::noReductions(producer->maybeRFactor()); const auto& consumer_root = consumer->root(); - // Add key-value iterdomain pair to the map. - auto updatePairwiseRootDomainMap = - [&root_dims_to_map, producer_to_consumer, &dom_map]( - IterDomain* map_key_id, IterDomain* map_value_id) { - if (!producer_to_consumer) { - std::swap(map_key_id, map_value_id); - } - if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { - dom_map.insert(std::make_pair(map_key_id, map_value_id)); - } - }; + // Check following conditions and add key-value iterdomain pair to domain map: + // 1. Do not map broadcast ID to non-broadcast ID unless map_broadcast_ = + // true. + // 2. Do not map Symbolic ID if the extents are not identical unless + // map_symbolic_ = true. + auto updatePairwiseRootDomainMap = [&](IterDomain* producer_id, + IterDomain* consumer_id) { + if (!map_broadcast_ && + producer_id->isBroadcast() != consumer_id->isBroadcast()) { + return; + } + + // Condition: At least one ID is symbolic. + // + // If map_symbolic_ is true: + // Map these IDs regardless of other considerations. + // + // If map_symbolic_ is false (default): + // Map these only if their extents are identical. IterType::Symbolic + // reflects that the extent might evaluate to 1 for some inputs, in which + // case it may be valid to use those domains in a broadcast op. If the + // extents are exactly the same between two aligned IterDomains, the + // Symbolic one will be concretized to the same IterType as the other, so + // they should be mapped with one another. + if (!map_symbolic_ && + (producer_id->isSymbolic() || consumer_id->isSymbolic()) && + (!producer_id->extent()->sameAs(consumer_id->extent()))) { + return; + } + + IterDomain* map_key_id = producer_id; + IterDomain* map_value_id = consumer_id; + + if (!producer_to_consumer) { + std::swap(map_key_id, map_value_id); + } + + if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { + dom_map.insert(std::make_pair(map_key_id, map_value_id)); + } + }; // For MatmulOp, use the corresponding mapped input iterdomains. if (MatmulOp* op = dynamic_cast(consumer_tv_->definition())) { // Check if the producer is lhs/rhs input MatmulRole input_role = - producer->sameAs(op->inA()) ? MatmulRole::INPUT_A : MatmulRole::INPUT_B; + producer->sameAs(op->inA()->as()->domain()) + ? MatmulRole::INPUT_A + : MatmulRole::INPUT_B; auto out_size = consumer_root.size(); // For MatmulOp, the input iterdomains at a given index do not necessarily @@ -150,14 +182,18 @@ std::unordered_map PairwiseRootDomainMap::map( // input and output for index=2 // 2. `B, M, K] x [K, N] -> [B, M, N]`: For input B, the second iterdomain // maps to the third output iterdomain. - const std::vector& aligned_producer_id = + const std::vector& aligned_producer_ids = ops::mapMatmulOpIterDomains(producer_root, input_role, out_size); for (auto inx : c10::irange(out_size)) { - IterDomain* map_key_id = aligned_producer_id.at(inx); - IterDomain* map_value_id = consumer_root.at(inx); - updatePairwiseRootDomainMap(map_key_id, map_value_id); + IterDomain* producer_id = aligned_producer_ids.at(inx); + IterDomain* consumer_id = consumer_root.at(inx); + if (producer_id == nullptr) { + continue; + } + updatePairwiseRootDomainMap(producer_id, consumer_id); } + return dom_map; } @@ -171,8 +207,6 @@ std::unordered_map PairwiseRootDomainMap::map( // 2. IDs that may have different extents (e.g., non indexed // domains of torch_gather) // 3. Squeeze and unsqueeze - // 4. Broadcast and non broadcast - // 5. Symbolic ID with different extent from other ID // Condition 1: when the producer ID is the dim of a select-like op if (producer_id == indexed_producer_id) { @@ -217,38 +251,7 @@ std::unordered_map PairwiseRootDomainMap::map( continue; } - // Condition 4 - if (!map_broadcast_ && - producer_id->isBroadcast() != consumer_id->isBroadcast()) { - itc++; - itp++; - continue; - } - - // Condition 5 - // At least one ID is symbolic. - // - // If map_symbolic_ is true: - // Map these IDs regardless of other considerations. - // - // If map_symbolic_ is false (default): - // Map these only if their extents are identical. IterType::Symbolic - // reflects that the extent might evaluate to 1 for some inputs, in which - // case it may be valid to use those domains in a broadcast op. If the - // extents are exactly the same between two aligned IterDomains, the - // Symbolic one will be concretized to the same IterType as the other, so - // they should be mapped with one another. - if (!map_symbolic_ && - (producer_id->isSymbolic() || consumer_id->isSymbolic()) && - (!producer_id->extent()->sameAs(consumer_id->extent()))) { - itc++; - itp++; - continue; - } - - IterDomain* map_key_id = producer_id; - IterDomain* map_value_id = consumer_id; - updatePairwiseRootDomainMap(map_key_id, map_value_id); + updatePairwiseRootDomainMap(producer_id, consumer_id); itc++; itp++; diff --git a/csrc/scheduler/all_schedulers.h b/csrc/scheduler/all_schedulers.h index 5bdf013c2de..08a33343d6a 100644 --- a/csrc/scheduler/all_schedulers.h +++ b/csrc/scheduler/all_schedulers.h @@ -6,6 +6,7 @@ */ // clang-format on #pragma once +#include #include #include #include diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp new file mode 100644 index 00000000000..c600b2f0bea --- /dev/null +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -0,0 +1,33 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include + +namespace nvfuser { + +// Check if the fusion has a single MatmulOp node +bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { + auto exprs = fusion->exprs(); + if (exprs.size() == 1 && exprs.front()->isA()) { + return true; + } + scheduler_debug_utils::canScheduleRejectReason( + heuristicType(), + "Fusion must contain a single expression of type MatmulOp"); + return false; +} + +void ExprEvalScheduler::schedule(Fusion* fusion) { + fusion->aliasOutputToInput( + fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); +} + +} // namespace nvfuser diff --git a/csrc/scheduler/expr_eval_sched.h b/csrc/scheduler/expr_eval_sched.h new file mode 100644 index 00000000000..a3c99626501 --- /dev/null +++ b/csrc/scheduler/expr_eval_sched.h @@ -0,0 +1,49 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { + +class Fusion; +class SchedulerRuntimeInfo; +class HeuristicSummary; + +// ExprEval scheduler represents the case where we allocate outputs directly +// using EE. No code is generated. +class ExprEvalScheduler : public SchedulerEntry { + public: + explicit ExprEvalScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) + : SchedulerEntry(heuristicType()) { + params_ = + std::make_shared("", runtime_info.getIndexType()); + } + + // This scheduler only accepts MatmulOp. + static bool canScheduleCompileTime(Fusion* fusion); + + static bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + return true; + } + + constexpr static ScheduleHeuristic heuristicType() { + return ScheduleHeuristic::ExprEval; + } + + void schedule(Fusion* fusion) override; +}; + +} // namespace nvfuser diff --git a/csrc/scheduler/heuristic.h b/csrc/scheduler/heuristic.h index 92ba567fc41..a0d73543cbc 100644 --- a/csrc/scheduler/heuristic.h +++ b/csrc/scheduler/heuristic.h @@ -25,11 +25,20 @@ class HeuristicParams : public PolymorphicBase { return "Undefined Heuristic Params"; } - virtual size_t hash() const = 0; - - virtual bool sameAs(const std::shared_ptr& other) const = 0; + virtual size_t hash() const { + return 0; + }; + + virtual bool sameAs(const std::shared_ptr& other) const { + if (!other->isStrictlyA()) { + return false; + } + return other->cparams == cparams; + } - virtual std::shared_ptr clone() const = 0; + virtual std::shared_ptr clone() const { + return std::make_shared(); + } HeuristicParams() = default; HeuristicParams(std::string tag, PrimDataType index_type) diff --git a/csrc/scheduler/heuristic_types.cpp b/csrc/scheduler/heuristic_types.cpp index c85266685c7..3b1c5a7e97e 100644 --- a/csrc/scheduler/heuristic_types.cpp +++ b/csrc/scheduler/heuristic_types.cpp @@ -29,6 +29,8 @@ std::string toString(ScheduleHeuristic sh) { return "transpose"; case ScheduleHeuristic::Matmul: return "matmul"; + case ScheduleHeuristic::ExprEval: + return "expr_eval"; case ScheduleHeuristic::None: return "none"; default: diff --git a/csrc/scheduler/heuristic_types.h b/csrc/scheduler/heuristic_types.h index e30169141b4..bbc3949c8d6 100644 --- a/csrc/scheduler/heuristic_types.h +++ b/csrc/scheduler/heuristic_types.h @@ -55,11 +55,13 @@ enum class ScheduleHeuristic { InnerPersistent, InnerOuterPersistent, OuterPersistent, - Transpose + Transpose, + ExprEval }; //! Define a schedule table to loop over all the heuristics in priority order. -constexpr std::array all_heuristics_in_priority_order = { +constexpr std::array all_heuristics_in_priority_order = { + ScheduleHeuristic::ExprEval, ScheduleHeuristic::NoOp, ScheduleHeuristic::Matmul, ScheduleHeuristic::Reduction, diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index a9680ecb2f5..55448a29167 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -206,6 +206,9 @@ bool checkCanSchedule( case ScheduleHeuristic::Matmul: return checkCanSchedule( fusion, runtime_info, data_cache); + case ScheduleHeuristic::ExprEval: + return checkCanSchedule( + fusion, runtime_info, data_cache); default: NVF_ERROR(false, "unreachable"); return false; @@ -252,6 +255,10 @@ bool checkCanSchedule( scheduler_entry = std::make_unique(fusion, runtime_info, data_cache); break; + case ScheduleHeuristic::ExprEval: + scheduler_entry = + std::make_unique(fusion, runtime_info, data_cache); + break; default: NVF_ERROR(false, "unreachable"); } @@ -342,6 +349,9 @@ HeuristicSummary::HeuristicSummary( NVF_ERROR(canSchedule, "Could not schedule matmul (run time)"); break; } + case ScheduleHeuristic::ExprEval: + ExprEvalScheduler::canScheduleRunTime(fusion, runtime_info, this); + break; default: NVF_ERROR(false, "unknown heuristic"); } @@ -415,8 +425,9 @@ void HeuristicSummary::validate() const { entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO)); break; } + case ScheduleHeuristic::ExprEval: case ScheduleHeuristic::Matmul: { - // TODO: add a proper set of checks + // TODO: add a proper set of checks for matmul break; } default: diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index 11fe6c33f9b..e9c0b190f2c 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -417,7 +417,7 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { auto tv0 = makeConcreteTensor(a_shape, DataType::Half); auto tv1 = makeConcreteTensor(b_shape, DataType::Half); - auto tv2 = eagerMatmul(tv0, tv1); + auto tv2 = matmul(tv0, tv1); fusion->addInput(tv0); fusion->addInput(tv1); @@ -427,14 +427,8 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeConcrete) { at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); at::Tensor out_ref = at::matmul(t0, t1); - FusionExecutor fe; - fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); - fe.compileFusion(fusion.get(), {t0, t1}); - auto out = fe.runFusion({t0, t1}); - - // Verify that fusion compilation was skipped. - EXPECT_FALSE(fe.hasCompiledKernel()); + FusionExecutorCache fec(std::move(fusion)); + auto out = fec.runFusionWithInputs({t0, t1}); EXPECT_TRUE(at::allclose(out[0], out_ref)); } @@ -447,7 +441,7 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeSymbolic) { auto tv0 = makeSymbolicTensor(a_shape, DataType::Half); auto tv1 = makeSymbolicTensor(b_shape, DataType::Half); - auto tv2 = eagerMatmul(tv0, tv1); + auto tv2 = matmul(tv0, tv1); fusion->addInput(tv0); fusion->addInput(tv1); @@ -457,14 +451,8 @@ TEST_P(ATenNodesParametrizedTest, MatmulNodeSymbolic) { at::Tensor t1 = at::randn(b_shape, at::kHalf).cuda(); at::Tensor out_ref = at::matmul(t0, t1); - FusionExecutor fe; - fusion->aliasOutputToInput( - fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate); - fe.compileFusion(fusion.get(), {t0, t1}); - auto out = fe.runFusion({t0, t1}); - - // Verify that fusion compilation was skipped. - EXPECT_FALSE(fe.hasCompiledKernel()); + FusionExecutorCache fec(std::move(fusion)); + auto out = fec.runFusionWithInputs({t0, t1}); EXPECT_TRUE(at::allclose(out[0], out_ref)); } diff --git a/tests/python/pytest_input_generators.py b/tests/python/pytest_input_generators.py index 53220008b90..8237381d6ed 100644 --- a/tests/python/pytest_input_generators.py +++ b/tests/python/pytest_input_generators.py @@ -1488,7 +1488,31 @@ def vector_at_error_generator( ), error_type, error_msg -def matmul_or_linear_input_generator( +def matmul_input_generator( + op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs +): + make_arg = partial( + make_tensor, + dtype=dtype, + device="cuda", + low=None, + high=None, + requires_grad=requires_grad, + ) + + B = 64 + M = 512 + N = 256 + K = 32 + + shapes_a = ((K,), (M, K), (1, K), (B, M, K), (B, 1, M, K)) + shapes_b = ((K,), (K, N), (K, 1), (B, K, N)) + + for shape_a, shape_b in itertools.product(shapes_a, shapes_b): + yield SampleInput(make_arg(shape_a), make_arg(shape_b)) + + +def linear_input_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): make_arg = partial( @@ -1507,17 +1531,11 @@ def multiply_range(maximum, step): map(pow, itertools.repeat(step, num_steps), range(1, num_steps + 1)) ) - is_linear = op.name == "linear" - # Ranges of tensor sizes: 8, 64, 512, 4096, 32768, ... # Use a Cartesian product to create a wide range of matrix shapes # I'll stop at 512 as possible numerical difference may show up. M, N, K = itertools.repeat(multiply_range(512, 8), 3) for M, N, K in itertools.product(M, N, K): lhs_shape = (M, K) - rhs_shape = (N, K) if is_linear else (K, N) - yield ( - SampleInput(make_arg(lhs_shape), make_arg(rhs_shape), make_arg((N,))) - if is_linear - else SampleInput(make_arg(lhs_shape), make_arg(rhs_shape)) - ) + rhs_shape = (N, K) + yield (SampleInput(make_arg(lhs_shape), make_arg(rhs_shape), make_arg((N,)))) diff --git a/tests/python/pytest_opinfos.py b/tests/python/pytest_opinfos.py index 740676f934b..5d69f57891a 100644 --- a/tests/python/pytest_opinfos.py +++ b/tests/python/pytest_opinfos.py @@ -48,7 +48,8 @@ var_mean_generator, vector_at_error_generator, where_error_generator, - matmul_or_linear_input_generator, + matmul_input_generator, + linear_input_generator, ) from pytest_utils import ( bool_int_dtypes, @@ -1115,7 +1116,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 else (torch.float16,) ), - sample_input_generator=matmul_or_linear_input_generator, + sample_input_generator=matmul_input_generator, reference=torch.matmul, ) matmul_ops.append(matmul_opinfo) @@ -1131,7 +1132,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 else (torch.float16,) ), - sample_input_generator=matmul_or_linear_input_generator, + sample_input_generator=linear_input_generator, reference=torch.nn.functional.linear, ) linear_ops.append(linear_opinfo) diff --git a/version.txt b/version.txt index 7179039691c..abd410582de 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.3 +0.2.4