From 700e7207d6b819cda37cfb6909a054d99eef3ccd Mon Sep 17 00:00:00 2001
From: Egor Churaev
Date: Fri, 11 Aug 2023 23:18:03 +0300
Subject: [PATCH] [Relay] Disable exception for ADT in mixed precision pass

If the topology contains a while loop and we want to transform it to
mixed precision, we get the exception "ADT are not supported for mixed
precision pass". This happens because the while loop is implemented as
a lambda that is assigned to a VarNode. This commit changes the
behavior of the ToMixedPrecision pass: instead of throwing an
exception, it now does nothing and leaves such calls unchanged. A
corresponding regression test is added.
---
 src/relay/transforms/to_mixed_precision.cc    |  9 ++---
 tests/python/relay/test_to_mixed_precision.py | 35 ++++++++++++++++++-
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 4638ee547706..5026b1bcba79 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -350,10 +350,11 @@ class MixedPrecisionPass : public MixedModeMutator {
 
     // TODO(AndrewZhaoLuo): Support ADTs
     // Relay's algebraic data types are not supported yet.
-    ICHECK(!cur_op.as<GlobalVarNode>()       // used to declare functions for recursion
-           && !cur_op.as<ConstructorNode>()  // constructing ADT types
-           && !cur_op.as<VarNode>())         // used for calling recursive functions
-        << "Algebraic Data Types (ADT) are not supported yet for mixed precision pass.";
+    bool isADT = (cur_op.as<GlobalVarNode>()       // used to declare functions for recursion
+                  || cur_op.as<ConstructorNode>()  // constructing ADT types
+                  || cur_op.as<LetNode>()          // used for binding lambdas
+                  || cur_op.as<VarNode>());        // used for calling recursive functions
+    if (isADT) return post;
 
     // Get info on the operation being called:
     // conversion category (int), accumulation dtype (str), output dtype (str)
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index a802eee6d644..4c97642498d9 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -49,7 +49,6 @@ def verify_mixed_precision_output_close(
     atol: float = 0,
     keep_orig_output_dtype=False,
 ) -> tvm.runtime.Module:
-
     mod = InferType()(mod)
 
     result_fp32 = run_module(mod, mod_params)
@@ -586,5 +585,39 @@ def test_clip_with_pre_op(target_precision):
     assert tvm.ir.structural_equal(expected_mod, output_mod)
 
 
+def test_loop(target_precision):
+    i = relay.var("i", shape=(), dtype="int32")
+    st = relay.var("st", shape=(relay.Any(), 1), dtype="int32")
+
+    def int32(val):
+        return relay.const(val, "int32")
+
+    def _cond(i, st):
+        return relay.op.min(relay.op.less(i, int32(10)))
+
+    def _body(i, st):
+        i_vec = relay.op.reshape(i, (1, 1))
+        ret = relay.op.concatenate([st, i_vec], axis=0)
+        return i + int32(1), ret
+
+    loop = relay.loops.while_loop(_cond, [i, st], _body)
+    start = relay.var("start", shape=(), dtype="int32")
+    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
+    func = relay.Function([start], relay.TupleGetItem(body, 1))
+    mod = tvm.IRModule()
+    mod["main"] = func
+
+    mod_params = {
+        "start": np.random.uniform(-1, 1, size=()).astype("int32"),
+    }
+    output_mod = verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
+    )
+
+    # Create expected module
+    expected_mod = InferType()(mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
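
Note (not part of the patch): the following is a minimal standalone sketch of the scenario the new test_loop test covers, assuming a TVM build that already includes this change. It builds the same while-loop module as the test and applies ToMixedPrecision directly; before this change the pass failed the ADT ICHECK, and with it the call into the loop is simply left untouched.

import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

# Loop state: a scalar counter and a growing (N, 1) tensor, as in test_loop.
i = relay.var("i", shape=(), dtype="int32")
st = relay.var("st", shape=(relay.Any(), 1), dtype="int32")

def _cond(i, st):
    # Continue while i < 10.
    return relay.op.min(relay.op.less(i, relay.const(10, "int32")))

def _body(i, st):
    # Append the current counter value to the state tensor.
    i_vec = relay.op.reshape(i, (1, 1))
    return i + relay.const(1, "int32"), relay.op.concatenate([st, i_vec], axis=0)

# while_loop returns a Let that binds the recursive loop function to a fresh
# Var; calling it produces a Call whose op is that Let, which the pass now
# skips instead of raising the ADT exception.
loop = relay.loops.while_loop(_cond, [i, st], _body)
start = relay.var("start", shape=(), dtype="int32")
ret = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))

mod = tvm.IRModule()
mod["main"] = relay.Function([start], relay.TupleGetItem(ret, 1))

mod = InferType()(mod)
out = ToMixedPrecision("float16")(mod)  # completes; the int32 loop stays as is
print(out)

The regression test wraps the same module in verify_mixed_precision_output_close, which additionally runs both modules and asserts that the pass leaves the IR structurally unchanged.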