diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index d93a7fde639a..507a6243cb0c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -478,6 +478,31 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N } } +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, CodeGenOpenCL* p) { + if (op->dtype.lanes() == 1) { + os << opstr << "(("; + p->PrintType(op->a->dtype, os); + os << ")"; + p->PrintExpr(op->a, os); + os << ", ("; + p->PrintType(op->b->dtype, os); + os << ")"; + p->PrintExpr(op->b, os); + os << ')'; + } else { + p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); + } +} + +void CodeGenOpenCL::VisitExpr_(const MinNode* op, std::ostream& os) { + PrintBinaryExpr(op, "min", os, this); +} + +void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { + PrintBinaryExpr(op, "max", os, this); +} + void CodeGenOpenCL::SetTextureScope( const std::unordered_map& scope) { // NOLINT(*) for (auto& texture : scope) { diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index a8c293c03056..8c36a817753c 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -65,6 +65,10 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const StoreNode* op) final; // NOLINT(*) + // overload min and max to avoid ambiguous call errors + void VisitExpr_(const MinNode* op, std::ostream& os) final; + void VisitExpr_(const MaxNode* op, std::ostream& os) final; + private: // whether enable fp16 and fp64 extension bool enable_fp16_{false}; diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index 56392ec8cccc..2ac2ec9dd9e9 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -142,4 +142,5 @@ def check_erf(dev, n, dtype): if __name__ == "__main__": test_opencl_ternary_expression() test_opencl_inf_nan() + test_opencl_max() test_opencl_erf()