Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions csrc/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ const PolymorphicValue& ExpressionEvaluator::evaluate(ParallelType pt) {
}

const PolymorphicValue& ExpressionEvaluator::evaluate(const Val* value) {
return evaluateHelper(value, known_values_);
return evaluate(value, known_values_);
}

PolymorphicValue ExpressionEvaluator::evaluate(const Val* value) const {
std::unordered_map<const Val*, PolymorphicValue> known_values;
return evaluateHelper(value, known_values);
return evaluate(value, known_values);
}

const PolymorphicValue& ExpressionEvaluator::evaluateHelper(
const PolymorphicValue& ExpressionEvaluator::evaluate(
const Val* value,
std::unordered_map<const Val*, PolymorphicValue>& known_values) const {
if (precomputed_values_ && precomputed_values_->ready()) {
Expand All @@ -213,16 +213,7 @@ const PolymorphicValue& ExpressionEvaluator::evaluateHelper(
if (!maybe_concrete_value.get().hasValue()) {
if (auto def = value->definition()) {
FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
std::vector<PolymorphicValue> inputs;
inputs.reserve(def->inputs().size());
for (auto i : def->inputs()) {
const auto& eval_i = evaluateHelper(i, known_values);
if (!eval_i.hasValue()) {
return null_;
}
inputs.emplace_back(eval_i);
}
auto outputs = def->evaluate(*this, inputs);
auto outputs = def->evaluate(*this, known_values);
for (auto i : c10::irange(def->outputs().size())) {
known_values[def->output(i)] = std::move(outputs[i]);
}
Expand Down
12 changes: 8 additions & 4 deletions csrc/expr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,16 @@ class ExpressionEvaluator {
//! Try to evaluate a parallel dimension
const PolymorphicValue& evaluate(ParallelType pt);

//! Try to evaluate a value using const evaluator ref
//! Evaluates a value through a const evaluator reference.
//! Initializes a known_values map to store intermediate values in lieu of
//! known_values_.
NVF_API PolymorphicValue evaluate(const Val* value) const;

//! Base evaluate method called by other overloads and Expr::evaluate.
const PolymorphicValue& evaluate(
const Val* value,
std::unordered_map<const Val*, PolymorphicValue>& known_values) const;

bool isKnown(const Val* value) const {
return known_values_.count(value) > 0;
}
Expand Down Expand Up @@ -88,9 +95,6 @@ class ExpressionEvaluator {
const Val* value,
const std::unordered_map<const Val*, PolymorphicValue>&
additional_known_values) const;
const PolymorphicValue& evaluateHelper(
const Val* value,
std::unordered_map<const Val*, PolymorphicValue>& known_values) const;

private:
// TODO: Consider make this const. It can't be const as bind() of
Expand Down
15 changes: 15 additions & 0 deletions csrc/ir/base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,21 @@ std::vector<PolymorphicValue> Expr::evaluate(
"Please override the evaluate method");
}

std::vector<PolymorphicValue> Expr::evaluate(
const ExpressionEvaluator& ee,
std::unordered_map<const Val*, PolymorphicValue>& known_values) const {
std::vector<PolymorphicValue> expr_inputs;
expr_inputs.reserve(inputs().size());
for (auto inp : inputs()) {
const auto& eval_i = ee.evaluate(inp, known_values);
if (!eval_i.hasValue()) {
return {std::monostate{}};
}
expr_inputs.emplace_back(eval_i);
}
return this->evaluate(ee, expr_inputs);
}

void Expr::addDataAttribute(PolymorphicValue attr) {
addAttribute(IrBuilder::create<Val>(container(), std::move(attr)));
}
Expand Down
11 changes: 11 additions & 0 deletions csrc/ir/base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,17 @@ class NVF_API Expr : public Statement {
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const;

// This version allows evaluation of multiple ops together instead of one op
// at a time by overriding and skipping computation of intermediate inputs
// that are not required. For example:
// 1. CatOp is internally preceded by PadOp but the ATen evaluation uses only
// the unpadded inputs and the evaluation of padded inputs can be skipped.
// 2. Evaluating patterns in matmul fallback such as MmaOp + Cast/ MmaOp +
// Bias + Cast
virtual std::vector<PolymorphicValue> evaluate(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be a code comment here about why this version of evaluate (and thus the added complexity) is needed and how it's different from the version above.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. Should be more clear now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you fix the format of the comment? It shows some unnecessary indentation.

const ExpressionEvaluator& ee,
std::unordered_map<const Val*, PolymorphicValue>& known_values) const;

// Input/output accessors
const auto& inputs() const {
return inputs_;
Expand Down
3 changes: 2 additions & 1 deletion csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2240,7 +2240,8 @@ class NVF_API CatOp : public Expr {
std::string toInlineString(int indent_size = 0) const override;
std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
std::unordered_map<const Val*, PolymorphicValue>& known_values)
const override;

int64_t concatenatedDim() const {
return attribute<int64_t>(0);
Expand Down
18 changes: 12 additions & 6 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4459,14 +4459,20 @@ Val* CatOp::getPred(int input_idx) const {

std::vector<PolymorphicValue> CatOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
std::vector<at::Tensor> in;
std::unordered_map<const Val*, PolymorphicValue>& known_values) const {
// CatOp is preceded by a PadOp internally.
// For ATen evaluation, directly compute the unpadded inputs.
std::vector<at::Tensor> unpadded_inputs;
unpadded_inputs.reserve(inputs().size());
int64_t concat_dim = concatenatedDim();
for (auto i : c10::irange(inputs.size())) {
auto unpadded_inp = ee.evaluate(input(i)->definition()->input(0));
in.push_back(unpadded_inp.as<at::Tensor>());
for (Val* inp : inputs()) {
NVF_CHECK(
inp->definition() != nullptr && inp->definition()->isA<PadOp>(),
"Expected CatOp to be preceded by a PadOp.");
auto eval_i = ee.evaluate(inp->definition()->input(0), known_values);
unpadded_inputs.push_back(eval_i.as<at::Tensor>());
}
return {at::cat(in, concat_dim)};
return {at::cat(unpadded_inputs, concat_dim)};
}

} // namespace nvfuser
29 changes: 29 additions & 0 deletions test/test_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,4 +695,33 @@ TEST_F(ExprEvalTest, SumDiv) {
evaluator.evaluate(out);
}

// Verify that the padded inputs are not evaluated
TEST_F(ExprEvalTest, CatOp) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(2);
auto tv1 = makeContigTensor(2);
auto tv2 = cat({tv0, tv1}, 0);
fusion.addInput(tv0);
fusion.addInput(tv1);
fusion.addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({3, 2}, options);
auto t1 = at::randn({3, 2}, options);

ExpressionEvaluator evaluator;
evaluator.bind(tv0, t0);
evaluator.bind(tv1, t1);

at::Tensor out = evaluator.evaluate(tv2).as<at::Tensor>();

for (auto padded_in : tv2->definition()->inputs()) {
EXPECT_FALSE(evaluator.isKnown(padded_in));
}

EXPECT_TRUE(at::equal(out, at::cat({t0, t1}, 0)));
}

} // namespace nvfuser