Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions csrc/evaluator_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values)
makeUnaryOp(uop);
} else if (auto bop = dynamic_cast<BinaryOp*>(def)) {
makeBinaryOp(bop);
} else if (auto top = dynamic_cast<TernaryOp*>(def)) {
makeTernaryOp(top);
} else {
// There could be some ops not supported yet. For these ops, we will
// bind their outputs. So ignoring them here.
Expand Down Expand Up @@ -435,14 +437,36 @@ void NaiveValueMachine::makeBinaryOp(BinaryOp* bop) {
dest_[index] = out;
}

//! Lower a ternary IR expression into a single machine instruction.
//!
//! Reads the evaluator index of each of the three inputs and of the first
//! output of `top`, reports an error for any index that is negative, and
//! then fills a freshly created instruction entry with the ternary op type
//! and those four indices.
void NaiveValueMachine::makeTernaryOp(TernaryOp* top) {
  // Collect every workspace index up front, before any validation runs.
  const auto& operands = top->inputs();
  const int in0 = operands[0]->evaluatorIndex();
  const int in1 = operands[1]->evaluatorIndex();
  const int in2 = operands[2]->evaluatorIndex();
  const int out = top->outputs()[0]->evaluatorIndex();

  NVF_ERROR(in0 >= 0, "Integer Machine: unknown first input: ", top);
  NVF_ERROR(in1 >= 0, "Integer Machine: unknown second input: ", top);
  NVF_ERROR(in2 >= 0, "Integer Machine: unknown third input: ", top);
  NVF_ERROR(out >= 0, "Integer Machine: unknown out: ", top);

  // Reserve a default-initialized entry, then overwrite the fields that
  // are meaningful for a ternary instruction.
  const int slot = makeInstructionEntry();
  dest_[slot] = out;
  src0_[slot] = in0;
  src1_[slot] = in1;
  src2_[slot] = in2;
  top_type_[slot] = top->getTernaryOpType();
  inst_type_[slot] = InstructionType::TERNARY_OP;
}

//! Append one instruction entry holding default values to every
//! per-instruction buffer and return the position of that new entry.
//!
//! The returned value is the instruction count before the append; the
//! count is incremented as part of this call.
int NaiveValueMachine::makeInstructionEntry() {
  const int new_entry = num_of_instructions_;
  ++num_of_instructions_;

  // Operand and destination slots start out as -1.
  src0_.emplace_back(-1);
  src1_.emplace_back(-1);
  src2_.emplace_back(-1);
  dest_.emplace_back(-1);

  // Instruction kind, per-arity operator types and data type each get a
  // placeholder value; callers overwrite the fields they need.
  inst_type_.emplace_back(InstructionType::UNARY_OP);
  uop_type_.emplace_back(UnaryOpType::Abs);
  bop_type_.emplace_back(BinaryOpType::Add);
  top_type_.emplace_back(TernaryOpType::Where);
  data_type_.emplace_back(DataType::Null);

  return new_entry;
}
Expand All @@ -459,6 +483,9 @@ void NaiveValueMachine::runInstruction(int index) {
case InstructionType::BINARY_OP:
runBinaryOp(index);
break;
case InstructionType::TERNARY_OP:
runTernaryOp(index);
break;
}
}

Expand Down Expand Up @@ -574,6 +601,70 @@ void NaiveValueMachine::runBinaryOp(int index) {
case BinaryOpType::Gcd:
dest = gcd(lhs, rhs);
break;
case BinaryOpType::LT:
dest = lhs < rhs;
break;
case BinaryOpType::LE:
dest = lhs <= rhs;
break;
case BinaryOpType::Eq:
dest = lhs == rhs;
break;
case BinaryOpType::NE:
dest = lhs != rhs;
break;
case BinaryOpType::GE:
dest = lhs >= rhs;
break;
case BinaryOpType::GT:
dest = lhs > rhs;
break;
default:
NVF_CHECK(false, "Unexpected operator type ", bop_type_[index]);
}

precomputed_values_.defined_[dest_index] = true;
}

void NaiveValueMachine::runTernaryOp(int index) {
using namespace PolymorphicValue_functions;
int src0_index = src0_[index];
int src1_index = src1_[index];
int src2_index = src2_[index];
bool src0_is_const = precomputed_values_.is_constant_[src0_index];
bool src1_is_const = precomputed_values_.is_constant_[src1_index];
bool src2_is_const = precomputed_values_.is_constant_[src2_index];

bool src_defined =
(precomputed_values_.defined_[src0_index] || src0_is_const) &&
(precomputed_values_.defined_[src1_index] || src1_is_const) &&
(precomputed_values_.defined_[src2_index] || src2_is_const);

if (!src_defined) {
return;
}
int dest_index = dest_[index];

auto& a = precomputed_values_.values_[src0_index];
auto& b = precomputed_values_.values_[src1_index];
auto& c = precomputed_values_.values_[src2_index];
auto& dest = precomputed_values_.values_[dest_index];

switch (top_type_[index]) {
case TernaryOpType::Clamp:
dest = std::min(std::max(a, b), c);
break;
case TernaryOpType::Lerp:
// This is the same lerp computed in helpers.cu
// https://math.stackexchange.com/a/1798323
dest = (c < 0.5) ? a + c * (b - a) : b - (b - a) * (1.0 - c);
break;
case TernaryOpType::Threshold:
dest = a <= b ? c : a;
break;
case TernaryOpType::Where:
dest = a ? b : c;
break;
default:
NVF_CHECK(!"Unexpected operator type");
}
Expand Down
23 changes: 18 additions & 5 deletions csrc/evaluator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ struct TensorArgAbstract;
//! PrecomputedValues that will provide the workspace
//! containing the concrete values for the values.
class NaiveValueMachine {
//! The generic types of instructions supported for this
//! machine, currently only binary and unary.
enum class InstructionType { UNARY_OP, BINARY_OP, SET_OP };
//! The generic types of instructions supported for this machine.
enum class InstructionType { UNARY_OP, BINARY_OP, TERNARY_OP, SET_OP };

public:
//! Constructor lowers all the expr IR nodes stored in precomputed_values
Expand All @@ -56,6 +55,9 @@ class NaiveValueMachine {
//! Convert a binary IR expr to an instruction
void makeBinaryOp(BinaryOp* bop);

//! Convert a ternary IR expr to an instruction
void makeTernaryOp(TernaryOp* bop);

//! Create an empty instruction with all default values
//! and place it at the end of the instruction buffer.
int makeInstructionEntry();
Expand All @@ -71,6 +73,9 @@ class NaiveValueMachine {
//! Runs a binary operation at given index of instruction buffer
void runBinaryOp(int index);

//! Runs a ternary operation at given index of instruction buffer
void runTernaryOp(int index);

private:
friend PrecomputedValues;

Expand Down Expand Up @@ -98,10 +103,14 @@ class NaiveValueMachine {
//! value at each index corresponding other ops.
std::vector<DataType> data_type_;

//! Unary operator type if applicable, contains a default
//! value at each index corresponding to a unary op.
//! Binary operator type if applicable, contains a default
//! value at each index corresponding to a binary op.
std::vector<BinaryOpType> bop_type_;

//! Ternary operator type if applicable, contains a default
//! value at each index corresponding to a ternary op.
std::vector<TernaryOpType> top_type_;

//! Indexes of operands and destination of each instruction.
//! The indexes correspond to positions in the workspace
//! where concrete values are hosted.
Expand All @@ -113,6 +122,10 @@ class NaiveValueMachine {
//! each index corresponding to a unary op.
std::vector<int> src1_;

//! Operand 2 of each instruction, a default value at
//! each index corresponding to a unary or binary op.
std::vector<int> src2_;

//! Destination of each instruction.
std::vector<int> dest_;
};
Expand Down
19 changes: 15 additions & 4 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,23 @@ std::vector<PolymorphicValue> TernaryOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
using namespace PolymorphicValue_functions;
const auto& in1 = inputs.at(0);
const auto& in2 = inputs.at(1);
const auto& in3 = inputs.at(2);
const auto& a = inputs.at(0);
const auto& b = inputs.at(1);
const auto& c = inputs.at(2);
switch (getTernaryOpType()) {
case TernaryOpType::Clamp:
return {std::min(std::max(a, b), c)};
break;
case TernaryOpType::Lerp:
// This is the same lerp computed in helpers.cu
// https://math.stackexchange.com/a/1798323
return {(c < 0.5) ? a + c * (b - a) : b - (b - a) * (1.0 - c)};
break;
case TernaryOpType::Threshold:
return {(a <= b) ? c : a};
break;
case TernaryOpType::Where:
return {in1.as<bool>() ? in2 : in3};
return {a.as<bool>() ? b : c};
break;
default:
NVF_CHECK(
Expand Down
44 changes: 44 additions & 0 deletions test/test_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,4 +461,48 @@ TEST_F(ExprEvalTest, ReverseArray) {
EXPECT_EQ((std::vector<int64_t>)evaluator.evaluate(output), expect);
}

//! Test evaluating ternary ops (clamp, lerp, threshold, where) on constant
//! scalar inputs.
//!
//! The body of the loop runs twice: the first pass uses the evaluator as
//! is, and at the end of that pass a PrecomputedValues built from the
//! fusion is bound to the evaluator, so the second pass runs with it bound.
TEST_F(ExprEvalTest, TernaryOps) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  ExpressionEvaluator evaluator;

  auto* a = IrBuilder::create<Val>(7.0);
  auto* b = IrBuilder::create<Val>(3.8);
  auto* c = IrBuilder::create<Val>(0.8);
  auto* d = IrBuilder::create<Val>(0.2);
  auto* t = IrBuilder::create<Val>(true);
  auto* f = IrBuilder::create<Val>(false);

  // Run once without PrecomputedValues, then once with
  for ([[maybe_unused]] auto i : c10::irange(2)) {
    // clamp: value inside the range, above the upper bound, and below the
    // lower bound.
    EXPECT_EQ(evaluator.evaluate(clamp(b, c, a)), b->value());
    EXPECT_EQ(evaluator.evaluate(clamp(a, c, b)), b->value());
    EXPECT_EQ(evaluator.evaluate(clamp(d, c, b)), c->value());

    // lerp with a weight below 0.5 is evaluated as a + w * (b - a).
    EXPECT_EQ(
        evaluator.evaluate(lerp(a, b, d)),
        a->value() + d->value() * (b->value() - a->value()));

    // lerp with a weight of at least 0.5 is evaluated as
    // b - (b - a) * (1 - w). The expected value mirrors that exact
    // expression so the equality check does not depend on two
    // algebraically equal formulas rounding to the same double.
    EXPECT_EQ(
        evaluator.evaluate(lerp(a, b, c)),
        b->value() - (b->value() - a->value()) * (1.0 - c->value()));

    // threshold: input above the threshold, below it, and equal to it.
    EXPECT_EQ(evaluator.evaluate(threshold(a, c, b)), a->value());
    EXPECT_EQ(evaluator.evaluate(threshold(d, c, b)), b->value());
    EXPECT_EQ(evaluator.evaluate(threshold(d, d, b)), b->value());

    // where: true and false predicates.
    EXPECT_EQ(evaluator.evaluate(where(t, a, b)), a->value());
    EXPECT_EQ(evaluator.evaluate(where(f, a, b)), b->value());

    // Now bind a PrecomputedValues
    PrecomputedValues pv(&fusion);
    evaluator.bindPrecomputedValues(&pv);
  }
}

} // namespace nvfuser