diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 4638ee547706..5026b1bcba79 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -350,10 +350,11 @@ class MixedPrecisionPass : public MixedModeMutator { // TODO(AndrewZhaoLuo): Support ADTs // Relay's algebraic data types are not supported yet. - ICHECK(!cur_op.as() // used to declare functions for recursion - && !cur_op.as() // constructing ADT types - && !cur_op.as()) // used for calling recursive functions - << "Algebraic Data Types (ADT) are not supported yet for mixed precision pass."; + bool isADT = (cur_op.as() // used to declare functions for recursion + || cur_op.as() // constructing ADT types + || cur_op.as() // used for binding lambdas + || cur_op.as()); // used for calling recursive functions + if (isADT) return post; // Get info on the operation being called: // conversion category (int), accumulation dtype (str), output dtype (str) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index a802eee6d644..4c97642498d9 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -49,7 +49,6 @@ def verify_mixed_precision_output_close( atol: float = 0, keep_orig_output_dtype=False, ) -> tvm.runtime.Module: - mod = InferType()(mod) result_fp32 = run_module(mod, mod_params) @@ -586,5 +585,39 @@ def test_clip_with_pre_op(target_precision): assert tvm.ir.structural_equal(expected_mod, output_mod) +def test_loop(target_precision): + i = relay.var("i", shape=(), dtype="int32") + st = relay.var("st", shape=(relay.Any(), 1), dtype="int32") + + def int32(val): + return relay.const(val, "int32") + + def _cond(i, st): + return relay.op.min(relay.op.less(i, int32(10))) + + def _body(i, st): + i_vec = relay.op.reshape(i, (1, 1)) + ret = relay.op.concatenate([st, i_vec], axis=0) + return i + int32(1), ret + + loop = relay.loops.while_loop(_cond, [i, st], _body) + start = relay.var("start", shape=(), dtype="int32") + body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) + func = relay.Function([start], relay.TupleGetItem(body, 1)) + mod = tvm.IRModule() + mod["main"] = func + + mod_params = { + "start": np.random.uniform(-1, 1, size=()).astype("int32"), + } + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01 + ) + + # Create expected module + expected_mod = InferType()(mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) + + if __name__ == "__main__": tvm.testing.main()