OpenCL codegen: cast the operands of logical and/or explicitly and print Select as select().

Reviewer notes on this copy of the patch:
  1. The patch text had been collapsed onto three physical lines. The one-entry-per-line
     unified diff layout is restored here. Leading whitespace and blank context lines were
     reconstructed from the hunk line counts, so check them against the tree (or apply with
     whitespace tolerance).
  2. The template arguments of std::unordered_map were missing from two context lines in the
     collapsed text; <const VarNode*, std::string> is filled in. TODO: confirm against the tree.
  3. Comments were added to the new C++ functions and to the new test. In the test,
     tvm.tir.all now receives its two conditions directly and the device parameter is named
     dev, like check_erf. Hunk line counts are updated for these additions; the blob ids on
     the "index" lines are still the ones from the original submission.

diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index a0e19ca35cd9..7811e4debdbf 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -327,6 +327,11 @@ void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) {
 
 std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) {
   if (from == target) return value;
+  return CastTo(value, target);
+}
+
+// Same as CastFromTo() but without the from == target early return (no source type is given).
+std::string CodeGenOpenCL::CastTo(std::string value, DataType target) {
   std::ostringstream os;
   if (target.lanes() == 1) {
     os << "((";
@@ -512,6 +517,45 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) {
   PrintBinaryExpr(op, "max", os, this);
 }
 
+// Print "(a && b)" with each operand passed through CastTo(..., op->dtype).
+void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) {
+  std::ostringstream oss;
+  os << "(";
+  this->PrintExpr(op->a, oss);
+  os << CastTo(oss.str(), op->dtype);
+  oss.str("");
+  os << " && ";
+  this->PrintExpr(op->b, oss);
+  os << CastTo(oss.str(), op->dtype);
+  os << ")";
+}
+
+// Print "(a || b)" with each operand passed through CastTo(..., op->dtype).
+void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) {
+  std::ostringstream oss;
+  os << "(";
+  this->PrintExpr(op->a, oss);
+  os << CastTo(oss.str(), op->dtype);
+  oss.str("");
+  os << " || ";
+  this->PrintExpr(op->b, oss);
+  os << CastTo(oss.str(), op->dtype);
+  os << ")";
+}
+
+// Print Select as a select(false_value, true_value, condition) call.
+// NOTE(review): OpenCL select() ties the condition's element width to the value type; the
+// condition is printed without any width adjustment here - confirm for non-float32 dtypes.
+void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {
+  os << "select(";
+  PrintExpr(op->false_value, os);
+  os << ", ";
+  PrintExpr(op->true_value, os);
+  os << ", ";
+  PrintExpr(op->condition, os);
+  os << ")";
+}
+
 void CodeGenOpenCL::SetTextureScope(
     const std::unordered_map<const VarNode*, std::string>& scope) {  // NOLINT(*)
   for (auto& texture : scope) {
diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h
index 3508eef43185..a7f4483ee2a9 100644
--- a/src/target/source/codegen_opencl.h
+++ b/src/target/source/codegen_opencl.h
@@ -55,6 +55,7 @@ class CodeGenOpenCL final : public CodeGenC {
                     std::ostream& os);                                           // NOLINT(*)
   void PrintRestrict(const Var& v, std::ostream& os) final;                      // NOLINT(*)
   std::string CastFromTo(std::string value, DataType from, DataType target);     // NOLINT(*)
+  std::string CastTo(std::string value, DataType target);                        // NOLINT(*)
   void SetTextureScope(const std::unordered_map<const VarNode*, std::string>&);  // NOLINT(*)
 
   // overload visitor
@@ -69,6 +70,10 @@ class CodeGenOpenCL final : public CodeGenC {
   // overload min and max to avoid ambiguous call errors
   void VisitExpr_(const MinNode* op, std::ostream& os) final;
   void VisitExpr_(const MaxNode* op, std::ostream& os) final;
+  // overload and/or to cast both operands to the result type, and Select to emit select()
+  void VisitExpr_(const AndNode* op, std::ostream& os) final;
+  void VisitExpr_(const OrNode* op, std::ostream& os) final;
+  void VisitExpr_(const SelectNode* op, std::ostream& os) final;
 
  private:
   // whether enable fp16 and fp64 extension
diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py
index 2ac2ec9dd9e9..c25b3c2c86ea 100644
--- a/tests/python/unittest/test_target_codegen_opencl.py
+++ b/tests/python/unittest/test_target_codegen_opencl.py
@@ -139,8 +139,57 @@ def check_erf(dev, n, dtype):
     check_erf(dev, 1, "float64")
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_opencl
+def test_opencl_type_casting():
+    """Check the OpenCL source generated for a vectorized Select with a compound condition."""
+
+    def check_type_casting(dev, n, dtype):
+        block_size = 4
+        C = te.compute(
+            (n,),
+            lambda i: tvm.tir.Select(
+                tvm.tir.all(
+                    i // block_size == tvm.tir.const(3, "int32"),
+                    i % block_size == tvm.tir.const(3, "int32"),
+                ),
+                tvm.tir.const(1, dtype),
+                tvm.tir.const(0, dtype),
+            ),
+            name="C",
+        )
+        s = te.create_schedule(C.op)
+        (tx, vx) = s[C].split(s[C].op.axis[0], factor=block_size)
+        s[C].vectorize(vx)
+        thrx = te.thread_axis("threadIdx.x")
+
+        s[C].bind(tx, thrx)
+        fun = tvm.build(s, [C], target)
+
+        c = tvm.nd.empty((n,), dtype, dev)
+        assembly = fun.imported_modules[0].get_source()
+        # The expected strings below are written for float32 values and block_size == 4.
+        false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))"
+        true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))"
+        lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
+        rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
+        cond = "({} && {})".format(lcond, rcond)
+        # The expected call lists the false value first, then the true value, then the condition.
+        select = "select({}, {}, {})".format(false_branch, true_branch, cond)
+        count = assembly.count(select)
+        assert count == 1
+
+        # Launch the kernel once; the contents of c are not inspected.
+        fun(c)
+
+    dev = tvm.device(target, 0)
+
+    check_type_casting(dev, 16, "float32")
+
+
 if __name__ == "__main__":
     test_opencl_ternary_expression()
     test_opencl_inf_nan()
     test_opencl_max()
     test_opencl_erf()
+    test_opencl_type_casting()