diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 1e982a0f18a4..1657f895dcd7 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -128,6 +128,9 @@ # Error function doesn't seem to be able to be lowered into fp16 version in llvm. # Move to follow list when it does. "erf", + # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number + # not representable in fp16. + "arange", ] diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 7a3fbfafc089..1eac975e60da 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -230,6 +230,17 @@ def test_do_not_convert_softmax(): assert tvm.ir.structural_equal(mod, output_mod) +def test_do_not_convert_arange(): + """Arange is a red listed operation and therefore should never be fp16.""" + dtype = "float32" + arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype)) + mod = tvm.IRModule.from_expr(arange) + mod = tvm.relay.transform.InferType()(mod) + + output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0) + assert tvm.ir.structural_equal(mod, output_mod) + + def test_green_gray_propagates_simple(): """Conv is a green listed operation, while addition is gray.