diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp
index 3943c07ab82..0b7b6875ec6 100644
--- a/csrc/expr_evaluator.cpp
+++ b/csrc/expr_evaluator.cpp
@@ -191,15 +191,15 @@ const PolymorphicValue& ExpressionEvaluator::evaluate(ParallelType pt) {
 }
 
 const PolymorphicValue& ExpressionEvaluator::evaluate(const Val* value) {
-  return evaluateHelper(value, known_values_);
+  return evaluate(value, known_values_);
 }
 
 PolymorphicValue ExpressionEvaluator::evaluate(const Val* value) const {
   std::unordered_map<const Val*, PolymorphicValue> known_values;
-  return evaluateHelper(value, known_values);
+  return evaluate(value, known_values);
 }
 
-const PolymorphicValue& ExpressionEvaluator::evaluateHelper(
+const PolymorphicValue& ExpressionEvaluator::evaluate(
     const Val* value,
     std::unordered_map<const Val*, PolymorphicValue>& known_values) const {
   if (precomputed_values_ && precomputed_values_->ready()) {
@@ -213,16 +213,7 @@ const PolymorphicValue& ExpressionEvaluator::evaluateHelper(
   if (!maybe_concrete_value.get().hasValue()) {
     if (auto def = value->definition()) {
       FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
-      std::vector<PolymorphicValue> inputs;
-      inputs.reserve(def->inputs().size());
-      for (auto i : def->inputs()) {
-        const auto& eval_i = evaluateHelper(i, known_values);
-        if (!eval_i.hasValue()) {
-          return null_;
-        }
-        inputs.emplace_back(eval_i);
-      }
-      auto outputs = def->evaluate(*this, inputs);
+      auto outputs = def->evaluate(*this, known_values);
       for (auto i : c10::irange(def->outputs().size())) {
         known_values[def->output(i)] = std::move(outputs[i]);
       }
diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h
index 28f4cb2288c..f2a952e95f7 100644
--- a/csrc/expr_evaluator.h
+++ b/csrc/expr_evaluator.h
@@ -56,9 +56,16 @@ class ExpressionEvaluator {
   //! Try to evaluate a parallel dimension
   const PolymorphicValue& evaluate(ParallelType pt);
 
-  //! Try to evaluate a value using const evaluator ref
+  //! Evaluates a value through a const evaluator reference.
+  //! Initializes a known_values map to store intermediate values in lieu of
+  //! known_values_.
   NVF_API PolymorphicValue evaluate(const Val* value) const;
 
+  //! Base evaluate method called by other overloads and Expr::evaluate.
+  const PolymorphicValue& evaluate(
+      const Val* value,
+      std::unordered_map<const Val*, PolymorphicValue>& known_values) const;
+
   bool isKnown(const Val* value) const {
     return known_values_.count(value) > 0;
   }
@@ -88,9 +95,6 @@ class ExpressionEvaluator {
       const Val* value,
       const std::unordered_map<const Val*, PolymorphicValue>&
          additional_known_values) const;
-  const PolymorphicValue& evaluateHelper(
-      const Val* value,
-      std::unordered_map<const Val*, PolymorphicValue>& known_values) const;
 
  private:
   // TODO: Consider make this const. It can't be const as bind() of
diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp
index 94779496564..a54a5f0d2c2 100644
--- a/csrc/ir/base_nodes.cpp
+++ b/csrc/ir/base_nodes.cpp
@@ -401,6 +401,21 @@ std::vector<PolymorphicValue> Expr::evaluate(
       "Please override the evaluate method");
 }
 
+std::vector<PolymorphicValue> Expr::evaluate(
+    const ExpressionEvaluator& ee,
+    std::unordered_map<const Val*, PolymorphicValue>& known_values) const {
+  std::vector<PolymorphicValue> expr_inputs;
+  expr_inputs.reserve(inputs().size());
+  for (auto inp : inputs()) {
+    const auto& eval_i = ee.evaluate(inp, known_values);
+    if (!eval_i.hasValue()) {
+      return {std::monostate{}};
+    }
+    expr_inputs.emplace_back(eval_i);
+  }
+  return this->evaluate(ee, expr_inputs);
+}
+
 void Expr::addDataAttribute(PolymorphicValue attr) {
   addAttribute(IrBuilder::create<Val>(container(), std::move(attr)));
 }
diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h
index a705791252f..1862726946d 100644
--- a/csrc/ir/base_nodes.h
+++ b/csrc/ir/base_nodes.h
@@ -524,6 +524,17 @@ class NVF_API Expr : public Statement {
       const ExpressionEvaluator& ee,
       const std::vector<PolymorphicValue>& inputs) const;
 
+  // This version allows evaluation of multiple ops together instead of one op
+  // at a time by overriding and skipping computation of intermediate inputs
+  // that are not required. For example:
+  // 1. CatOp is internally preceded by PadOp but the ATen evaluation uses only
+  //    the unpadded inputs and the evaluation of padded inputs can be skipped.
+  // 2. Evaluating patterns in matmul fallback such as MmaOp + Cast/ MmaOp +
+  //    Bias + Cast
+  virtual std::vector<PolymorphicValue> evaluate(
+      const ExpressionEvaluator& ee,
+      std::unordered_map<const Val*, PolymorphicValue>& known_values) const;
+
   // Input/output accessors
   const auto& inputs() const {
     return inputs_;
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index e2ac73720da..5cf9dc952c5 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -2240,7 +2240,8 @@ class NVF_API CatOp : public Expr {
   std::string toInlineString(int indent_size = 0) const override;
   std::vector<PolymorphicValue> evaluate(
       const ExpressionEvaluator& ee,
-      const std::vector<PolymorphicValue>& inputs) const override;
+      std::unordered_map<const Val*, PolymorphicValue>& known_values)
+      const override;
 
   int64_t concatenatedDim() const {
     return attribute<int64_t>(0);
diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index ff4ccf09894..d8ae0b00ae1 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -4459,14 +4459,20 @@ Val* CatOp::getPred(int input_idx) const {
 
 std::vector<PolymorphicValue> CatOp::evaluate(
     const ExpressionEvaluator& ee,
-    const std::vector<PolymorphicValue>& inputs) const {
-  std::vector<at::Tensor> in;
+    std::unordered_map<const Val*, PolymorphicValue>& known_values) const {
+  // CatOp is preceded by a PadOp internally.
+  // For ATen evaluation, directly compute the unpadded inputs.
+  std::vector<at::Tensor> unpadded_inputs;
+  unpadded_inputs.reserve(inputs().size());
   int64_t concat_dim = concatenatedDim();
-  for (auto i : c10::irange(inputs.size())) {
-    auto unpadded_inp = ee.evaluate(input(i)->definition()->input(0));
-    in.push_back(unpadded_inp.as<at::Tensor>());
+  for (Val* inp : inputs()) {
+    NVF_CHECK(
+        inp->definition() != nullptr && inp->definition()->isA<PadOp>(),
+        "Expected CatOp to be preceded by a PadOp.");
+    auto eval_i = ee.evaluate(inp->definition()->input(0), known_values);
+    unpadded_inputs.push_back(eval_i.as<at::Tensor>());
   }
-  return {at::cat(in, concat_dim)};
+  return {at::cat(unpadded_inputs, concat_dim)};
 }
 
 } // namespace nvfuser
diff --git a/test/test_evaluator.cpp b/test/test_evaluator.cpp
index eb0978bc9fa..20cc39f49b7 100644
--- a/test/test_evaluator.cpp
+++ b/test/test_evaluator.cpp
@@ -695,4 +695,33 @@ TEST_F(ExprEvalTest, SumDiv) {
   evaluator.evaluate(out);
 }
 
+// Verify that the padded inputs are not evaluated
+TEST_F(ExprEvalTest, CatOp) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(2);
+  auto tv1 = makeContigTensor(2);
+  auto tv2 = cat({tv0, tv1}, 0);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({3, 2}, options);
+  auto t1 = at::randn({3, 2}, options);
+
+  ExpressionEvaluator evaluator;
+  evaluator.bind(tv0, t0);
+  evaluator.bind(tv1, t1);
+
+  at::Tensor out = evaluator.evaluate(tv2).as<at::Tensor>();
+
+  for (auto padded_in : tv2->definition()->inputs()) {
+    EXPECT_FALSE(evaluator.isKnown(padded_in));
+  }
+
+  EXPECT_TRUE(at::equal(out, at::cat({t0, t1}, 0)));
+}
+
 } // namespace nvfuser