From 8362b3204f2ba33d2714ffe21b210deb4dee41f5 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 11 May 2023 21:19:39 -0400 Subject: [PATCH 1/6] fix --- src/target/source/codegen_opencl.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index de96f923e2fa..14fbc4d8401c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -567,11 +567,7 @@ void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { os << ", "; PrintExpr(op->condition, oss); if (op->dtype.is_float()) { - if (op->condition.dtype().is_uint() || op->condition.dtype().is_int()) { - os << oss.str(); - } else { - os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); - } + os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); } else { os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype); } From 539f72b833164e083bc08b152e34c3ec38aff2db Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 13 May 2023 23:51:57 -0700 Subject: [PATCH 2/6] fix --- src/target/source/codegen_opencl.cc | 2 +- .../unittest/test_target_codegen_opencl.py | 35 +++++++++++++------ 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 14fbc4d8401c..8db4044b85d2 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -437,7 +437,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { LOG(FATAL) << "Unsupported type: " << op->dtype << ", currently only float and half are supported for image2d OpenCL codegen."; } - this->PrintExpr(op->args[0], ss); + this->PrintExpr(op->args[0], ss); ss << ", "; ss << "image_sampler, "; ss << "((int2)("; diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index bc2d0a84fd9d..b89e1d257de5 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -168,20 +168,35 @@ def check_type_casting(ctx, n, dtype): c = tvm.nd.empty((n,), dtype, ctx) assembly = fun.imported_modules[0].get_source() - false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))" - true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))" - lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" - rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))" - cond = "({} && {})".format(lcond, rcond) - select = "select({}, {}, {})".format(false_branch, true_branch, cond) - count = assembly.count(select) - assert count == 1 - - fun(c) + + if dtype == "float32": + false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))" + true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))" + lcond = "convert_int4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" + rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))" + cond = "({} && {})".format(lcond, rcond) + select = "select({}, {}, {})".format(false_branch, true_branch, cond) + count = assembly.count(select) + assert count == 1 + + fun(c) + + elif dtype == "float16": + false_branch = "((half4)((half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f))" + true_branch = "((half4)((half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f))" + lcond = "convert_short4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))" + rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))))" + cond = "({} && {})".format(lcond, rcond) + select = "select({}, {}, {})".format(false_branch, true_branch, cond) + count = assembly.count(select) + assert count == 1 + + fun(c) dev = tvm.device(target, 0) check_type_casting(dev, 16, "float32") + check_type_casting(dev, 16, "float16") if __name__ == "__main__": From e3d9c5c44a59c1abba087923d121b93419fc05c2 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 14 May 2023 09:38:41 -0400 Subject: [PATCH 3/6] Update test_target_codegen_opencl.py --- tests/python/unittest/test_target_codegen_opencl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index b89e1d257de5..dcfa14755560 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -196,7 +196,8 @@ def check_type_casting(ctx, n, dtype): dev = tvm.device(target, 0) check_type_casting(dev, 16, "float32") - check_type_casting(dev, 16, "float16") + # fp16 is not yet supported in ci + # check_type_casting(dev, 16, "float16") if __name__ == "__main__": From 5741a59858fb6b99be82c5762a8115a02b485563 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 14 May 2023 09:41:24 -0400 Subject: [PATCH 4/6] Update test_target_codegen_opencl.py --- tests/python/unittest/test_target_codegen_opencl.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index dcfa14755560..edd0f649f4d7 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -168,7 +168,7 @@ def check_type_casting(ctx, n, dtype): c = tvm.nd.empty((n,), dtype, ctx) assembly = fun.imported_modules[0].get_source() - + if dtype == "float32": false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))" true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))" @@ -178,9 +178,8 @@ def check_type_casting(ctx, n, dtype): select = "select({}, {}, {})".format(false_branch, true_branch, cond) count = assembly.count(select) assert count == 1 - fun(c) - + elif dtype == "float16": false_branch = "((half4)((half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f))" true_branch = "((half4)((half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f))" @@ -190,7 +189,6 @@ def check_type_casting(ctx, n, dtype): select = "select({}, {}, {})".format(false_branch, true_branch, cond) count = assembly.count(select) assert count == 1 - fun(c) dev = tvm.device(target, 0) From b76539c5b4144261826247082e65fc3ee89356fc Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 14 May 2023 11:55:29 -0400 Subject: [PATCH 5/6] Update test_target_codegen_opencl.py --- tests/python/unittest/test_target_codegen_opencl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index edd0f649f4d7..67dc37363ea9 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -189,7 +189,7 @@ def check_type_casting(ctx, n, dtype): select = "select({}, {}, {})".format(false_branch, true_branch, cond) count = assembly.count(select) assert count == 1 - fun(c) + fun(c) dev = tvm.device(target, 0) From 9528148260a61ba6fbe0b76380bdfc90afb49853 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 14 May 2023 12:39:58 -0400 Subject: [PATCH 6/6] Update codegen_opencl.cc --- src/target/source/codegen_opencl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 8db4044b85d2..14fbc4d8401c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -437,7 +437,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { LOG(FATAL) << "Unsupported type: " << op->dtype << ", currently only float and half are supported for image2d OpenCL codegen."; } - this->PrintExpr(op->args[0], ss); + this->PrintExpr(op->args[0], ss); ss << ", "; ss << "image_sampler, "; ss << "((int2)(";