Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,39 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype;
}
if (lhs.dtype() == rhs.dtype()) return;
// Only do very simple type coversion
// int->float, DataType::Int(32)->int(64)
// require the types to be relatively consistent
// This will the reduce amount code generated by operators
// and also help user to find potential type conversion problems.
if (!lhs.dtype().is_float() &&
(rhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) {
// int->float

// We keep dtypes conversion to be relatively consistent to reduce the amount code generated by
// operators. This can be helpful for users to find potential type conversion problems. The
// following are exceptions:
if (lhs.dtype().is_float() && rhs.dtype().is_float()) {
// Given two dissimilar floats, cast the lower bit version to the higher bit version.
// E.g. fp16 + fp32 --> fp32 + fp32
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
} else if (lhs.dtype().bits() > rhs.dtype().bits()) {
rhs = cast(lhs.dtype(), rhs);
}
} else if (!lhs.dtype().is_float() &&
(rhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) {
// Cast int->float when the other operand is a float
lhs = cast(rhs.dtype(), lhs);
} else if ((lhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) &&
!rhs.dtype().is_float()) {
// int->float
// Cast int->float when the other operand is a float
rhs = cast(lhs.dtype(), rhs);
} else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_uint())) {
// promote int to higher bits
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
} else {
rhs = cast(lhs.dtype(), rhs);
}
} else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_int())) {
// Handle mixing signed and unsigned integers
int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span);
rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span);
Expand Down
65 changes: 56 additions & 9 deletions tests/python/unittest/test_tir_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
# under the License.
import tvm
from tvm import tir
from tvm._ffi.base import TVMError
from tvm.ir.transform import PassContext
import itertools
import pytest


def build_tir_func(func):
Expand All @@ -30,15 +33,59 @@ def build_tir_func(func):


def test_scalar_add():
a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
c = a + b
c = tir.ret(c)
c = tir.Evaluate(c)
func = tir.PrimFunc([a, b], c)
func = build_tir_func(func)
out = func(1.0, 2.0)
assert out == 3.0
# All these types should be interchangeable with each other
# E.g. float16 + float32 upconverts the float16 --> float32
# Meanwhile if an int or float or together the int will be
# cast to the float type.
lhs_types = ["float32", "float16", "int32", "int64"]
rhs_types = ["float32", "float16"]
for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types):
# Input vars should be float32, we will cast to test for upcasting between them
lhs_input = tir.Var("lhs", "float32")
rhs_input = tir.Var("rhs", "float32")
lhs = tir.Cast(lhs_type, lhs_input)
rhs = tir.Cast(rhs_type, rhs_input)
output = lhs + rhs
output = tir.ret(output)
output = tir.Evaluate(output)
func = tir.PrimFunc([lhs_input, rhs_input], output)
func = build_tir_func(func)
out = func(1.0, 2.0)
assert out == 3.0


def assignment_helper(store_dtype, value_dtype):
store = tir.Var("store", dtype=store_dtype)
value = tir.Var("value", dtype=value_dtype)
tir.Let(store, value, body=store)


def test_fail_implicit_downcasts_same_type():
# These lists should be sorted
bits = [8, 16, 32, 64]
for type in ["float", "int", "uint"]:
for i in range(len(bits) - 1):
with pytest.raises(TVMError):
assignment_helper(
store_dtype=f"{type}{bits[i]}", value_dtype=f"{type}{bits[i + 1]}"
)


def test_cast_between_types():
# We should only be able to assign values with the same types
bits = [16, 32]
types = ["float", "int", "uint"]
for store_type, store_bits, value_type, value_bits in itertools.product(
types, bits, types, bits
):
store_dtype = f"{store_type}{store_bits}"
value_dtype = f"{value_type}{value_bits}"
if store_dtype == value_dtype:
assignment_helper(store_dtype, value_dtype)
else:
# TODO: we might want to allow casts between uint and int types
with pytest.raises(TVMError):
assignment_helper(store_dtype, value_dtype)


def test_control_flow_jump():
Expand Down
19 changes: 14 additions & 5 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,22 @@ def verify_callop_float_only(f):
rhs = te.var("rhs", dtype=rhs_dtype)
if "float" not in lhs_dtype and "float" not in rhs_dtype:
check_throws(lambda: f(lhs, rhs))
elif "float" in lhs_dtype and "float" in rhs_dtype and lhs_dtype != rhs_dtype:
check_throws(lambda: f(lhs, rhs))
elif "float" in lhs_dtype:
out = f(lhs, rhs)
assert out.dtype == lhs_dtype
assert out.args[0].dtype == lhs_dtype
assert out.args[1].dtype == lhs_dtype

# Upcasting for floating point types
dtypes = [lhs_dtype, rhs_dtype]
if "float64" in dtypes:
target_dtype = "float64"
elif "float32" in dtypes:
target_dtype = "float32"
else:
target_dtype = "int32"
assert out.dtype == target_dtype

# Final inputs are the right type
assert out.args[0].dtype == target_dtype
assert out.args[1].dtype == target_dtype
else:
out = f(lhs, rhs)
assert out.dtype == rhs_dtype
Expand Down