diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index af78804837ba..aca6d1b50b0e 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -112,24 +112,31 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; - // Only do very simple type coversion - // int->float, DataType::Int(32)->int(64) - // require the types to be relatively consistent - // This will the reduce amount code generated by operators - // and also help user to find potential type conversion problems. - if (!lhs.dtype().is_float() && - (rhs.dtype().is_float() || - datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { - // int->float + + // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by + // operators. This can be helpful for users to find potential type conversion problems. The + // following are exceptions: + if (lhs.dtype().is_float() && rhs.dtype().is_float()) { + // Given two dissimilar floats, cast the lower bit version to the higher bit version. + // E.g. fp16 + fp32 --> fp32 + fp32 + if (lhs.dtype().bits() < rhs.dtype().bits()) { + lhs = cast(rhs.dtype(), lhs); + } else if (lhs.dtype().bits() > rhs.dtype().bits()) { + rhs = cast(lhs.dtype(), rhs); + } + } else if (!lhs.dtype().is_float() && + (rhs.dtype().is_float() || + datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { + // Cast int->float when the other operand is a float lhs = cast(rhs.dtype(), lhs); } else if ((lhs.dtype().is_float() || datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) && !rhs.dtype().is_float()) { - // int->float + // Cast int->float when the other operand is a float rhs = cast(lhs.dtype(), rhs); } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) || (lhs.dtype().is_uint() && rhs.dtype().is_uint())) { - // promote int to higher bits + // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 if (lhs.dtype().bits() < rhs.dtype().bits()) { lhs = cast(rhs.dtype(), lhs); } else { @@ -137,6 +144,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) || (lhs.dtype().is_uint() && rhs.dtype().is_int())) { + // Handle mixing signed and unsigned integers int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span); rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span); diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 6e081a179059..3928c817a125 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -16,7 +16,10 @@ # under the License. import tvm from tvm import tir +from tvm._ffi.base import TVMError from tvm.ir.transform import PassContext +import itertools +import pytest def build_tir_func(func): @@ -30,15 +33,59 @@ def build_tir_func(func): def test_scalar_add(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") - c = a + b - c = tir.ret(c) - c = tir.Evaluate(c) - func = tir.PrimFunc([a, b], c) - func = build_tir_func(func) - out = func(1.0, 2.0) - assert out == 3.0 + # All these types should be interchangeable with each other + # E.g. float16 + float32 upconverts the float16 --> float32 + # Meanwhile if an int or float or together the int will be + # cast to the float type. + lhs_types = ["float32", "float16", "int32", "int64"] + rhs_types = ["float32", "float16"] + for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types): + # Input vars should be float32, we will cast to test for upcasting between them + lhs_input = tir.Var("lhs", "float32") + rhs_input = tir.Var("rhs", "float32") + lhs = tir.Cast(lhs_type, lhs_input) + rhs = tir.Cast(rhs_type, rhs_input) + output = lhs + rhs + output = tir.ret(output) + output = tir.Evaluate(output) + func = tir.PrimFunc([lhs_input, rhs_input], output) + func = build_tir_func(func) + out = func(1.0, 2.0) + assert out == 3.0 + + +def assignment_helper(store_dtype, value_dtype): + store = tir.Var("store", dtype=store_dtype) + value = tir.Var("value", dtype=value_dtype) + tir.Let(store, value, body=store) + + +def test_fail_implicit_downcasts_same_type(): + # These lists should be sorted + bits = [8, 16, 32, 64] + for type in ["float", "int", "uint"]: + for i in range(len(bits) - 1): + with pytest.raises(TVMError): + assignment_helper( + store_dtype=f"{type}{bits[i]}", value_dtype=f"{type}{bits[i + 1]}" + ) + + +def test_cast_between_types(): + # We should only be able to assign values with the same types + bits = [16, 32] + types = ["float", "int", "uint"] + for store_type, store_bits, value_type, value_bits in itertools.product( + types, bits, types, bits + ): + store_dtype = f"{store_type}{store_bits}" + value_dtype = f"{value_type}{value_bits}" + if store_dtype == value_dtype: + assignment_helper(store_dtype, value_dtype) + else: + # TODO: we might want to allow casts between uint and int types + with pytest.raises(TVMError): + assignment_helper(store_dtype, value_dtype) def test_control_flow_jump(): diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index f1f8cf70d0c9..78eab6bdde9f 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -146,13 +146,22 @@ def verify_callop_float_only(f): rhs = te.var("rhs", dtype=rhs_dtype) if "float" not in lhs_dtype and "float" not in rhs_dtype: check_throws(lambda: f(lhs, rhs)) - elif "float" in lhs_dtype and "float" in rhs_dtype and lhs_dtype != rhs_dtype: - check_throws(lambda: f(lhs, rhs)) elif "float" in lhs_dtype: out = f(lhs, rhs) - assert out.dtype == lhs_dtype - assert out.args[0].dtype == lhs_dtype - assert out.args[1].dtype == lhs_dtype + + # Upcasting for floating point types + dtypes = [lhs_dtype, rhs_dtype] + if "float64" in dtypes: + target_dtype = "float64" + elif "float32" in dtypes: + target_dtype = "float32" + else: + target_dtype = "int32" + assert out.dtype == target_dtype + + # Final inputs are the right type + assert out.args[0].dtype == target_dtype + assert out.args[1].dtype == target_dtype else: out = f(lhs, rhs) assert out.dtype == rhs_dtype