Merged

35 commits
a695299
add matmul node, create output tensorview
Priya2698 Apr 25, 2024
961ea76
wip scheduler
Priya2698 Apr 25, 2024
fe23404
add to dispatch, scheduler heuristic, registry
Priya2698 Apr 26, 2024
075a83f
root map override
Priya2698 Apr 29, 2024
a7a15a8
rebase
Priya2698 May 1, 2024
15b9062
mapping for matmul ir node
Priya2698 May 1, 2024
44d13c4
use higher dim inp to create output
Priya2698 May 1, 2024
0cf2e6b
remove scheduler
Priya2698 May 2, 2024
bb52526
review comments
Priya2698 May 2, 2024
6c130d7
modify pairwise matching
Priya2698 May 2, 2024
6bb6031
modify matmul out allocation
Priya2698 May 4, 2024
b75d72f
move mapping logic to another function
Priya2698 May 6, 2024
54314bf
use mapping in root domain
Priya2698 May 6, 2024
22031f2
comment
Priya2698 May 6, 2024
76fd527
add dot product case
Priya2698 May 6, 2024
dc45b68
format
Priya2698 May 6, 2024
5cb9fe5
1D case, review comments
Priya2698 May 7, 2024
2047629
move common code
Priya2698 May 7, 2024
bc86019
format
Priya2698 May 7, 2024
6d14c32
review comments
Priya2698 May 9, 2024
9910676
lint
Priya2698 May 9, 2024
b9613f0
add to scheduler
Priya2698 May 2, 2024
6a059ab
check for matmul op
Priya2698 May 7, 2024
5b9a4a3
Update csrc/scheduler/expr_eval_sched.cpp
Priya2698 May 9, 2024
1ad32ea
remove unused functions, add defaults to heuristic param
Priya2698 May 9, 2024
a1b22e2
change scheduler order
Priya2698 May 10, 2024
bcd5791
fix comparison
Priya2698 May 13, 2024
7183408
check broadcast and symbolic conditions
Priya2698 May 13, 2024
ffd48df
rename API
Priya2698 May 13, 2024
8f80548
modify matmul generator to use cases from Thunder
Priya2698 May 13, 2024
a11c4a9
refactor code
Priya2698 May 13, 2024
d051311
bump version
Priya2698 May 13, 2024
41e2bbb
review comments
Priya2698 May 14, 2024
6e3aa0a
review comments
Priya2698 May 14, 2024
5786f13
format, clangtidy
Priya2698 May 14, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -205,6 +205,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp
${NVFUSER_SRCS_DIR}/scheduler/utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp
${NVFUSER_SRCS_DIR}/scheduler/expr_eval_sched.cpp
${NVFUSER_SRCS_DIR}/serde/polymorphic_value.cpp
${NVFUSER_SRCS_DIR}/serde/utils.cpp
${NVFUSER_SRCS_DIR}/swizzle.cpp
41 changes: 1 addition & 40 deletions csrc/ops/composite.cpp
@@ -54,42 +54,6 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
return dx;
}

// This function will add a castOp to the output of the matrix multiplication
// The implementation of linear can use this but will skip the cast (set cast
// flag as false) and add the bias.
TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input) {
NVF_CHECK(
a->nDims() == b->nDims(),
"The number of dimension of A and B do not match");
// TODO: We'll need to support nDims == 3 for bmm.
NVF_CHECK(
a->nDims() == 2,
"Only 2-D Tensors are supported, in the future we'll support 3-D as well!");

std::vector<bool> bcast_dims(a->nDims() + 1, false);
// A: [M, K, Bcast]
// B: [Bcast, K, N]
bcast_dims.at(bcast_dims.size() - 1) = true;
auto* tv0b = broadcast(a, bcast_dims);
bcast_dims.at(bcast_dims.size() - 1) = false;
bcast_dims.at(bcast_dims.size() - 3) = true;
auto* tv1b = broadcast(b, bcast_dims);

NVF_CHECK(
a->getDataType().value() == b->getDataType().value(),
"data types of inputs to matmul don't match");
auto* output = fusedMultiplySum(tv0b, tv1b, {-2});
if (cast_output_to_input) {
// For matmul, the output dtype should match input.
return maybeCastOp(a->getDataType().value(), output);
}
return output;
}

TensorView* matmul(TensorView* a, TensorView* b) {
return matmul(a, b, true /* cast output to input dtype */);
}

TensorView* linear(TensorView* a, TensorView* b, TensorView* bias) {
// TODO: Support 1+ dimensional A.
NVF_CHECK(
@@ -348,10 +312,7 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) {

} // namespace

// TODO (Priya): This will be renamed to matmul once we are ready to modify the
// python API backend. Keeping separate for now, to avoid breaking tests in
// Thunder.
TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b) {
TensorView* matmul(TensorView* tv_a, TensorView* tv_b) {
NVF_CHECK(
tv_a->nDims() > 0 && tv_b->nDims() > 0,
"Expected inputs to be atleast 1D, got: ",
15 changes: 4 additions & 11 deletions csrc/ops/composite.h
@@ -47,16 +47,6 @@ NVF_API LstmResult lstm(
TensorView* cell_x,
TensorView* out_x);

// Matmul function which takes in tensors with the shapes
// A[M,K] B[K,N], but the tensors may have different layouts
// via strides. All restrictions from the matmul APIs also
// apply here.
TensorView* matmul(TensorView* a, TensorView* b);
// This second matmul function is not exposed via
// the Python interface, but it does the guts of the work and
// can be used to create matmuls without a cast operation following it.
TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input);

// Linear function which takes in two tensors of shapes A[M,K] and
// B[N,K]. Takes in an optional bias of shape [N] and performs
// out = A * B_Transpose + bias. The output dtype matches the dtype
@@ -81,6 +71,9 @@ TensorView* leaky_relu(TensorView* x, Val* negative_slope);

NVF_API TensorView* view_as_real(TensorView* x);

TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b);
// Matmul function which takes in tensors with the shapes
// A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different
// layouts via strides. This has the same functionality as torch.matmul
TensorView* matmul(TensorView* tv_a, TensorView* tv_b);

} // namespace nvfuser
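
For context, here is a minimal usage sketch of the renamed API. The fusion-definition helpers below (FusionGuard, makeSymbolicTensor, addInput/addOutput) follow the pattern of nvfuser's C++ tests and are assumptions for illustration, not part of this diff:

#include <fusion.h>
#include <ops/all_ops.h>

using namespace nvfuser;

// Build a fusion whose only expression is the torch.matmul-style MatmulOp.
void buildMatmulFusion(Fusion* fusion) {
  FusionGuard fg(fusion);
  TensorView* tv_a = makeSymbolicTensor(3, DataType::Half); // A[B, M, K]
  TensorView* tv_b = makeSymbolicTensor(2, DataType::Half); // B[K, N]
  fusion->addInput(tv_a);
  fusion->addInput(tv_b);
  TensorView* tv_out = matmul(tv_a, tv_b); // out[B, M, N]
  fusion->addOutput(tv_out);
}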
103 changes: 53 additions & 50 deletions csrc/root_domain_map.cpp
@@ -123,23 +123,55 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
TensorDomain::noReductions(producer->maybeRFactor());
const auto& consumer_root = consumer->root();

// Add key-value iterdomain pair to the map.
auto updatePairwiseRootDomainMap =
[&root_dims_to_map, producer_to_consumer, &dom_map](
IterDomain* map_key_id, IterDomain* map_value_id) {
if (!producer_to_consumer) {
std::swap(map_key_id, map_value_id);
}
if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
dom_map.insert(std::make_pair(map_key_id, map_value_id));
}
};
// Check the following conditions and add key-value iterdomain pair to domain map:
// 1. Do not map broadcast ID to non-broadcast ID unless map_broadcast_ =
// true.
// 2. Do not map Symbolic ID if the extents are not identical unless
// map_symbolic_ = true.
auto updatePairwiseRootDomainMap = [&](IterDomain* producer_id,
IterDomain* consumer_id) {
if (!map_broadcast_ &&
producer_id->isBroadcast() != consumer_id->isBroadcast()) {
return;
}

// Condition: At least one ID is symbolic.
//
// If map_symbolic_ is true:
// Map these IDs regardless of other considerations.
//
// If map_symbolic_ is false (default):
// Map these only if their extents are identical. IterType::Symbolic
// reflects that the extent might evaluate to 1 for some inputs, in which
// case it may be valid to use those domains in a broadcast op. If the
// extents are exactly the same between two aligned IterDomains, the
// Symbolic one will be concretized to the same IterType as the other, so
// they should be mapped with one another.
if (!map_symbolic_ &&
(producer_id->isSymbolic() || consumer_id->isSymbolic()) &&
(!producer_id->extent()->sameAs(consumer_id->extent()))) {
return;
}

IterDomain* map_key_id = producer_id;
IterDomain* map_value_id = consumer_id;

if (!producer_to_consumer) {
std::swap(map_key_id, map_value_id);
}

if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
dom_map.insert(std::make_pair(map_key_id, map_value_id));
}
};

// For MatmulOp, use the corresponding mapped input iterdomains.
if (MatmulOp* op = dynamic_cast<MatmulOp*>(consumer_tv_->definition())) {
// Check if the producer is lhs/rhs input
MatmulRole input_role =
producer->sameAs(op->inA()) ? MatmulRole::INPUT_A : MatmulRole::INPUT_B;
producer->sameAs(op->inA()->as<TensorView>()->domain())
? MatmulRole::INPUT_A
: MatmulRole::INPUT_B;
auto out_size = consumer_root.size();

// For MatmulOp, the input iterdomains at a given index do not necessarily
@@ -150,14 +182,18 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
// input and output for index=2
// 2. `[B, M, K] x [K, N] -> [B, M, N]`: For input B, the second iterdomain
// maps to the third output iterdomain.
const std::vector<IterDomain*>& aligned_producer_id =
const std::vector<IterDomain*>& aligned_producer_ids =
ops::mapMatmulOpIterDomains(producer_root, input_role, out_size);

for (auto inx : c10::irange(out_size)) {
IterDomain* map_key_id = aligned_producer_id.at(inx);
IterDomain* map_value_id = consumer_root.at(inx);
updatePairwiseRootDomainMap(map_key_id, map_value_id);
IterDomain* producer_id = aligned_producer_ids.at(inx);
IterDomain* consumer_id = consumer_root.at(inx);
if (producer_id == nullptr) {
continue;
}
updatePairwiseRootDomainMap(producer_id, consumer_id);
}

return dom_map;
}

@@ -171,8 +207,6 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
// 2. IDs that may have different extents (e.g., non indexed
// domains of torch_gather)
// 3. Squeeze and unsqueeze
// 4. Broadcast and non broadcast
// 5. Symbolic ID with different extent from other ID

// Condition 1: when the producer ID is the dim of a select-like op
if (producer_id == indexed_producer_id) {
@@ -217,38 +251,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
continue;
}

// Condition 4
if (!map_broadcast_ &&
producer_id->isBroadcast() != consumer_id->isBroadcast()) {
itc++;
itp++;
continue;
}

// Condition 5
// At least one ID is symbolic.
//
// If map_symbolic_ is true:
// Map these IDs regardless of other considerations.
//
// If map_symbolic_ is false (default):
// Map these only if their extents are identical. IterType::Symbolic
// reflects that the extent might evaluate to 1 for some inputs, in which
// case it may be valid to use those domains in a broadcast op. If the
// extents are exactly the same between two aligned IterDomains, the
// Symbolic one will be concretized to the same IterType as the other, so
// they should be mapped with one another.
if (!map_symbolic_ &&
(producer_id->isSymbolic() || consumer_id->isSymbolic()) &&
(!producer_id->extent()->sameAs(consumer_id->extent()))) {
itc++;
itp++;
continue;
}

IterDomain* map_key_id = producer_id;
IterDomain* map_value_id = consumer_id;
updatePairwiseRootDomainMap(map_key_id, map_value_id);
updatePairwiseRootDomainMap(producer_id, consumer_id);

itc++;
itp++;
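
To make the alignment performed by ops::mapMatmulOpIterDomains concrete, here is a hedged, self-contained sketch using plain strings in place of IterDomain* (the helper name and the treatment of K are illustrative assumptions, not the nvfuser implementation). For `[B, M, K] x [K, N] -> [B, M, N]`, input B contributes only N, at the last output position; unmatched slots stay empty so the mapping loop above skips them, just as it skips nullptr entries:

#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the alignment of input B ([K, N]) against a
// 3-D output ([B, M, N]): only N lines up with an output iterdomain.
std::vector<std::optional<std::string>> alignInputB(size_t out_size) {
  std::vector<std::optional<std::string>> aligned(out_size, std::nullopt);
  aligned[out_size - 1] = "N"; // B's last dim maps to the output's last dim
  // In this simplified sketch K is treated as reduced away and B has no
  // batch dims, so the remaining slots stay empty and the caller skips them.
  return aligned;
}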
1 change: 1 addition & 0 deletions csrc/scheduler/all_schedulers.h
@@ -6,6 +6,7 @@
*/
// clang-format on
#pragma once
#include <scheduler/expr_eval_sched.h>
#include <scheduler/matmul.h>
#include <scheduler/no_op.h>
#include <scheduler/normalization_inner.h>
33 changes: 33 additions & 0 deletions csrc/scheduler/expr_eval_sched.cpp
@@ -0,0 +1,33 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <ir/utils.h>
#include <scheduler/debug_utils.h>
#include <scheduler/expr_eval_sched.h>
#include <scheduler/registry_utils.h>

namespace nvfuser {

// Check if the fusion has a single MatmulOp node
bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
auto exprs = fusion->exprs();
if (exprs.size() == 1 && exprs.front()->isA<MatmulOp>()) {
return true;
}
scheduler_debug_utils::canScheduleRejectReason(
heuristicType(),
"Fusion must contain a single expression of type MatmulOp");
return false;
}

void ExprEvalScheduler::schedule(Fusion* fusion) {
fusion->aliasOutputToInput(
fusion->outputs()[0], /*input=*/nullptr, AllocationType::Evaluate);
}

} // namespace nvfuser
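
A hedged end-to-end sketch of what this scheduler enables. FusionExecutorCache, the kernel_cache.h header, and the ATen tensor setup follow common nvfuser test patterns and are assumptions, not part of this diff:

#include <ATen/ATen.h>
#include <fusion.h>
#include <kernel_cache.h>

#include <memory>

using namespace nvfuser;

// A fusion containing exactly one MatmulOp is claimed by ExprEvalScheduler;
// schedule() marks the output AllocationType::Evaluate, so the result is
// computed by expression evaluation (ATen) rather than a generated kernel.
void runSingleMatmulFusion() {
  auto fusion_ptr = std::make_unique<Fusion>();
  buildMatmulFusion(fusion_ptr.get()); // sketch from composite.h above

  FusionExecutorCache fec(std::move(fusion_ptr));
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kHalf);
  auto t_a = at::randn({8, 64, 32}, opts);
  auto t_b = at::randn({32, 16}, opts);
  auto outputs = fec.runFusionWithInputs({t_a, t_b});
}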
49 changes: 49 additions & 0 deletions csrc/scheduler/expr_eval_sched.h
@@ -0,0 +1,49 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <scheduler/heuristic.h>
#include <scheduler/registry.h>

namespace nvfuser {

class Fusion;
class SchedulerRuntimeInfo;
class HeuristicSummary;

// ExprEval scheduler represents the case where we allocate outputs directly
// using EE. No code is generated.
class ExprEvalScheduler : public SchedulerEntry {
public:
explicit ExprEvalScheduler(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr)
: SchedulerEntry(heuristicType()) {
params_ =
std::make_shared<HeuristicParams>("", runtime_info.getIndexType());
}

// This scheduler only accepts MatmulOp.
static bool canScheduleCompileTime(Fusion* fusion);

static bool canScheduleRunTime(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache) {
return true;
}

constexpr static ScheduleHeuristic heuristicType() {
return ScheduleHeuristic::ExprEval;
}

void schedule(Fusion* fusion) override;
};

} // namespace nvfuser
17 changes: 13 additions & 4 deletions csrc/scheduler/heuristic.h
@@ -25,11 +25,20 @@ class HeuristicParams : public PolymorphicBase {
return "Undefined Heuristic Params";
}

virtual size_t hash() const = 0;

virtual bool sameAs(const std::shared_ptr<HeuristicParams>& other) const = 0;
virtual size_t hash() const {
return 0;
};

virtual bool sameAs(const std::shared_ptr<HeuristicParams>& other) const {
if (!other->isStrictlyA<HeuristicParams>()) {
return false;
}
return other->cparams == cparams;
}

virtual std::shared_ptr<HeuristicParams> clone() const = 0;
virtual std::shared_ptr<HeuristicParams> clone() const {
return std::make_shared<HeuristicParams>();
}

HeuristicParams() = default;
HeuristicParams(std::string tag, PrimDataType index_type)
2 changes: 2 additions & 0 deletions csrc/scheduler/heuristic_types.cpp
@@ -29,6 +29,8 @@ std::string toString(ScheduleHeuristic sh) {
return "transpose";
case ScheduleHeuristic::Matmul:
return "matmul";
case ScheduleHeuristic::ExprEval:
return "expr_eval";
case ScheduleHeuristic::None:
return "none";
default:
6 changes: 4 additions & 2 deletions csrc/scheduler/heuristic_types.h
@@ -55,11 +55,13 @@ enum class ScheduleHeuristic {
InnerPersistent,
InnerOuterPersistent,
OuterPersistent,
Transpose
Transpose,
ExprEval
};

//! Define a schedule table to loop over all the heuristics in priority order.
constexpr std::array<ScheduleHeuristic, 8> all_heuristics_in_priority_order = {
constexpr std::array<ScheduleHeuristic, 9> all_heuristics_in_priority_order = {
ScheduleHeuristic::ExprEval,
Collaborator

Should NoOp come before ExprEval?

Collaborator Author

Some cases get accepted by NoOp scheduler, which is why I prioritized ExprEval scheduler.

We may need to change the heuristics of NoOp if we want to switch the order.

Collaborator @jacobhinkle May 14, 2024

Oh! Thanks for mentioning that. Does NoOp scheduler accept the cases where you have a single scalar output? Because it seems to me that it would do so based on this code:

Fuser/csrc/fusion.cpp, lines 341 to 359 in 8c18701:

bool Fusion::isNoOp() {
if (exprs().empty()) {
return true;
}
for (auto out_tv : ir_utils::filterByType<TensorView>(outputs())) {
const std::vector<IterDomain*>& root_dom =
TensorDomain::noReductions(out_tv->getMaybeRFactorDomain());
const bool size_zero =
std::any_of(root_dom.begin(), root_dom.end(), [](IterDomain* id) {
return id->extent()->isConstScalar() && id->extent()->evaluate() == 0;
});
if (!size_zero) {
return false;
}
}
return true;
}

We should add a special case for zero-dimensional outputs there. On second look, it seems like size_zero would be false in the case that root_dom.empty(). However, the code below might not properly handle zero-dimensional outputs:
// Check that all outputs are either broadcast or ignored reduction.
for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
auto concrete_dimension = TensorDomain::noReductions(
TensorDomain::noBroadcasts(out_tv->getLeafDomain()));
if (!concrete_dimension.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
heuristicType(), "output has a concrete dimension");
return false;
}
}

Collaborator Author

[  FAILED  ] 6 tests, listed below:
[  FAILED  ] ATenNodesParametrizedTest.MatmulNodeConcrete/2, where GetParam() = ({ 32 }, { 32, 1 })
[  FAILED  ] ATenNodesParametrizedTest.MatmulNodeConcrete/8, where GetParam() = ({ 1, 32 }, { 32 })
[  FAILED  ] ATenNodesParametrizedTest.MatmulNodeConcrete/10, where GetParam() = ({ 1, 32 }, { 32, 1 })
[  FAILED  ] ATenNodesParametrizedTest.MatmulNodeSymbolic/2, where GetParam() = ({ 32 }, { 32, 1 })
[  FAILED  ] ATenNodesParametrizedTest.MatmulNodeSymbolic/8, where GetParam() = ({ 1, 32 }, { 32 })
[  FAILED  ] ATenNodesParametrizedTest.MatmulNodeSymbolic/10, where GetParam() = ({ 1, 32 }, { 32, 1 })

It is likely because there are no reductions identified since we use ATen, and all the dimensions in the output are broadcast dimensions. So the cases where M/N = 1 get picked by NoOp.
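
A worked reading of one failing shape pair, as an illustration of the explanation above (not output from this PR):

// ({1, 32}, {32, 1}): A[1, 32] x B[32, 1] -> out[1, 1]
// The K=32 contraction happens inside the ATen evaluation, so the fusion IR
// exposes no reduction iterdomain; both remaining output iterdomains have
// extent 1 (broadcast), so the NoOp scheduler's "all outputs are broadcast
// or ignored reduction" check accepts the fusion before ExprEval is tried.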

ScheduleHeuristic::NoOp,
ScheduleHeuristic::Matmul,
ScheduleHeuristic::Reduction,