From b1f6467dfa36c347b2eda8f0eced8373f52a9363 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 4 Aug 2021 12:45:38 +0900 Subject: [PATCH 1/3] [AMP] Do not allow fp16 cast on arange inputs --- python/tvm/relay/transform/mixed_precision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 1e982a0f18a4..085ed9dba011 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -128,6 +128,8 @@ # Error function doesn't seem to be able to be lowered into fp16 version in llvm. # Move to follow list when it does. "erf", + # Do not allow arange arguments (begin/end) to be fp16 + "arange", ] From 082b1ff222b39c4c7d6a4dcd765e1e5435a5a583 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 4 Aug 2021 13:59:58 +0900 Subject: [PATCH 2/3] add test --- tests/python/relay/test_to_mixed_precision.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 7a3fbfafc089..1eac975e60da 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -230,6 +230,17 @@ def test_do_not_convert_softmax(): assert tvm.ir.structural_equal(mod, output_mod) +def test_do_not_convert_arange(): + """Arange is a red listed operation and therefore should never be fp16.""" + dtype = "float32" + arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype)) + mod = tvm.IRModule.from_expr(arange) + mod = tvm.relay.transform.InferType()(mod) + + output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0) + assert tvm.ir.structural_equal(mod, output_mod) + + def test_green_gray_propagates_simple(): """Conv is a green listed operation, while addition is gray. From e21470b7eb46c1471a9007834ffe734000b99e21 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 4 Aug 2021 15:45:43 +0900 Subject: [PATCH 3/3] Add comment explaining the issue with fp16 "end" --- python/tvm/relay/transform/mixed_precision.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 085ed9dba011..1657f895dcd7 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -128,7 +128,8 @@ # Error function doesn't seem to be able to be lowered into fp16 version in llvm. # Move to follow list when it does. "erf", - # Do not allow arange arguments (begin/end) to be fp16 + # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number + # not representable in fp16. "arange", ]