From 119bcf0c234295b167cd2be83583e987f350d5aa Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 20 Jul 2021 12:38:01 -0700 Subject: [PATCH 1/6] handle upcasting case --- src/tir/op/op.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index af78804837ba..ef19a6457462 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -112,12 +112,23 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; - // Only do very simple type coversion + + // We keep casting pretty simple + // Two different floating point types will upconvert the lower bit floating point + // to the same type as the higher bit version. E.g. fp16 + fp32 --> fp32 + fp32. + // Furthermore: // int->float, DataType::Int(32)->int(64) // require the types to be relatively consistent // This will the reduce amount code generated by operators // and also help user to find potential type conversion problems. - if (!lhs.dtype().is_float() && + if (lhs.dtype().is_float() && rhs.dtype().is_float()) { + int max_num_bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); + if (lhs.dtype().bits() != max_num_bits) { + lhs = cast(rhs.dtype(), lhs); + } else { + rhs = cast(lhs.dtype(), rhs); + } + } else if (!lhs.dtype().is_float() && (rhs.dtype().is_float() || datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { // int->float From 67e16d56bc16865050428acb814e7e21d9dd33df Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 20 Jul 2021 14:14:15 -0700 Subject: [PATCH 2/6] test upcasting tests for tir --- tests/python/unittest/test_tir_base.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 6e081a179059..692b2348fb36 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -17,6 +17,8 @@ import tvm from tvm import tir from tvm.ir.transform import PassContext +import itertools +import numpy as np def build_tir_func(func): @@ -30,15 +32,25 @@ def build_tir_func(func): def test_scalar_add(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") - c = a + b - c = tir.ret(c) - c = tir.Evaluate(c) - func = tir.PrimFunc([a, b], c) - func = build_tir_func(func) - out = func(1.0, 2.0) - assert out == 3.0 + # All these types should be interchangeable with each other + # E.g. float16 + float32 upconverts the float16 --> float32 + # Meanwhile if an int or float or together the int will be + # cast to the float type. + lhs_types = ["float32", "float16", "int32", "int64"] + rhs_types = ["float32", "float16"] + for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types): + # Input vars should be float32, we will cast to test for upcasting between them + lhs_input = tir.Var("lhs", "float32") + rhs_input = tir.Var("rhs", "float32") + lhs = tir.Cast(lhs_type, lhs_input) + rhs = tir.Cast(rhs_type, rhs_input) + output = lhs + rhs + output = tir.ret(output) + output = tir.Evaluate(output) + func = tir.PrimFunc([lhs_input, rhs_input], output) + func = build_tir_func(func) + out = func(1.0, 2.0) + assert out == 3.0 def test_control_flow_jump(): From 6b6f7a111e6df09ce38c4509a1dda99549027c86 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 20 Jul 2021 14:47:19 -0700 Subject: [PATCH 3/6] address comaniac comments --- src/tir/op/op.cc | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index ef19a6457462..4c46580aea14 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -113,34 +113,29 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } if (lhs.dtype() == rhs.dtype()) return; - // We keep casting pretty simple - // Two different floating point types will upconvert the lower bit floating point - // to the same type as the higher bit version. E.g. fp16 + fp32 --> fp32 + fp32. - // Furthermore: - // int->float, DataType::Int(32)->int(64) - // require the types to be relatively consistent - // This will the reduce amount code generated by operators - // and also help user to find potential type conversion problems. + // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by operators. + // This can be helpful for users to find potential type conversion problems. The following are exceptions: if (lhs.dtype().is_float() && rhs.dtype().is_float()) { - int max_num_bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); - if (lhs.dtype().bits() != max_num_bits) { + // Given two dissimilar floats, cast the lower bit version to the higher bit version. + // E.g. fp16 + fp32 --> fp32 + fp32 + if (lhs.dtype().bits() < rhs.dtype().bits()) { lhs = cast(rhs.dtype(), lhs); - } else { + } else if (lhs.dtype().bits() > rhs.dtype().bits()) { rhs = cast(lhs.dtype(), rhs); } } else if (!lhs.dtype().is_float() && (rhs.dtype().is_float() || datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { - // int->float + // Cast int->float when the other operand is a float lhs = cast(rhs.dtype(), lhs); } else if ((lhs.dtype().is_float() || datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) && !rhs.dtype().is_float()) { - // int->float + // Cast int->float when the other operand is a float rhs = cast(lhs.dtype(), rhs); } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) || (lhs.dtype().is_uint() && rhs.dtype().is_uint())) { - // promote int to higher bits + // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 if (lhs.dtype().bits() < rhs.dtype().bits()) { lhs = cast(rhs.dtype(), lhs); } else { @@ -148,6 +143,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) || (lhs.dtype().is_uint() && rhs.dtype().is_int())) { + // Handle mixing signed and unsigned integers int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span); rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span); From dcb79f1983e640191b06801e64f30edbfaa8a570 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 21 Jul 2021 09:29:23 -0700 Subject: [PATCH 4/6] formatting --- src/tir/op/op.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 4c46580aea14..aca6d1b50b0e 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -113,8 +113,9 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } if (lhs.dtype() == rhs.dtype()) return; - // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by operators. - // This can be helpful for users to find potential type conversion problems. The following are exceptions: + // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by + // operators. This can be helpful for users to find potential type conversion problems. The + // following are exceptions: if (lhs.dtype().is_float() && rhs.dtype().is_float()) { // Given two dissimilar floats, cast the lower bit version to the higher bit version. // E.g. fp16 + fp32 --> fp32 + fp32 @@ -124,8 +125,8 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) rhs = cast(lhs.dtype(), rhs); } } else if (!lhs.dtype().is_float() && - (rhs.dtype().is_float() || - datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { + (rhs.dtype().is_float() || + datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { // Cast int->float when the other operand is a float lhs = cast(rhs.dtype(), lhs); } else if ((lhs.dtype().is_float() || From 1d8ff3e13d7aeab2c9c5931a632be34181996d74 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 27 Jul 2021 12:23:07 -0700 Subject: [PATCH 5/6] add negative tests --- tests/python/unittest/test_tir_base.py | 37 +++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 692b2348fb36..3928c817a125 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -16,9 +16,10 @@ # under the License. import tvm from tvm import tir +from tvm._ffi.base import TVMError from tvm.ir.transform import PassContext import itertools -import numpy as np +import pytest def build_tir_func(func): @@ -53,6 +54,40 @@ def test_scalar_add(): assert out == 3.0 +def assignment_helper(store_dtype, value_dtype): + store = tir.Var("store", dtype=store_dtype) + value = tir.Var("value", dtype=value_dtype) + tir.Let(store, value, body=store) + + +def test_fail_implicit_downcasts_same_type(): + # These lists should be sorted + bits = [8, 16, 32, 64] + for type in ["float", "int", "uint"]: + for i in range(len(bits) - 1): + with pytest.raises(TVMError): + assignment_helper( + store_dtype=f"{type}{bits[i]}", value_dtype=f"{type}{bits[i + 1]}" + ) + + +def test_cast_between_types(): + # We should only be able to assign values with the same types + bits = [16, 32] + types = ["float", "int", "uint"] + for store_type, store_bits, value_type, value_bits in itertools.product( + types, bits, types, bits + ): + store_dtype = f"{store_type}{store_bits}" + value_dtype = f"{value_type}{value_bits}" + if store_dtype == value_dtype: + assignment_helper(store_dtype, value_dtype) + else: + # TODO: we might want to allow casts between uint and int types + with pytest.raises(TVMError): + assignment_helper(store_dtype, value_dtype) + + def test_control_flow_jump(): ib = tvm.tir.ir_builder.create() a = tir.Var("a", "float32") From f3369f5611a391ffe467c22975bc1f5e500ace55 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 28 Jul 2021 10:44:37 -0700 Subject: [PATCH 6/6] fix failing test now allow other things --- tests/python/unittest/test_tir_ops.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index f1f8cf70d0c9..78eab6bdde9f 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -146,13 +146,22 @@ def verify_callop_float_only(f): rhs = te.var("rhs", dtype=rhs_dtype) if "float" not in lhs_dtype and "float" not in rhs_dtype: check_throws(lambda: f(lhs, rhs)) - elif "float" in lhs_dtype and "float" in rhs_dtype and lhs_dtype != rhs_dtype: - check_throws(lambda: f(lhs, rhs)) elif "float" in lhs_dtype: out = f(lhs, rhs) - assert out.dtype == lhs_dtype - assert out.args[0].dtype == lhs_dtype - assert out.args[1].dtype == lhs_dtype + + # Upcasting for floating point types + dtypes = [lhs_dtype, rhs_dtype] + if "float64" in dtypes: + target_dtype = "float64" + elif "float32" in dtypes: + target_dtype = "float32" + else: + target_dtype = "int32" + assert out.dtype == target_dtype + + # Final inputs are the right type + assert out.args[0].dtype == target_dtype + assert out.args[1].dtype == target_dtype else: out = f(lhs, rhs) assert out.dtype == rhs_dtype