From 9617fd3e8c282a7b57180aa77f75dcd411198230 Mon Sep 17 00:00:00 2001
From: Gayatri Panchapakesan Kumari
Date: Wed, 27 Apr 2022 15:26:25 +0530
Subject: [PATCH 1/6] Fix mixed precision output type to original type

---
 src/relay/transforms/to_mixed_precision.cc | 35 ++++++++++++++++++++--
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 4ad3482f7464..313c3faf3c67 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -105,6 +105,8 @@ class MixedPrecisionPass : public MixedModeMutator {
    * encountered. Used for emitting warnings on missing ops in the pass.
    */
   std::unordered_map<std::string, int> missing_ops_;
+  const RelayExprNode* root_;
+  std::vector<DataType> original_dtype_;
 
   Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
     /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +280,18 @@ class MixedPrecisionPass : public MixedModeMutator {
  public:
   using MixedModeMutator::VisitExpr_;
 
-  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
-      : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
+  explicit MixedPrecisionPass(Expr base, DataType mixed_precision_type = DataType::Float(16))
+      : MixedModeMutator(),
+        mixed_precision_type_(mixed_precision_type),
+        root_(Downcast<Function>(base)->body.get()) {
+    if (root_->IsInstance<TupleNode>()) {
+      const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
+      for (Type t : tuple_type->fields) {
+        const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
+        original_dtype_.push_back(tensor_type->dtype);
+      }
+    } else if (root_->IsInstance<CallNode>())
+      original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
     if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
       LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                  << mixed_precision_type_;
@@ -381,6 +393,10 @@ class MixedPrecisionPass : public MixedModeMutator {
     if (accumulation_dtype != output_dtype) {
       output = CastArg(output, GetType(output), output_dtype);
     }
+    if (pre_call_node == static_cast<const CallNode*>(root_)) {
+      if (original_dtype_[0] != output_dtype)
+        output = CastArg(output, GetType(output), original_dtype_[0]);
+    }
     return output;
   }
 
@@ -396,6 +412,19 @@ class MixedPrecisionPass : public MixedModeMutator {
   Expr Rewrite_(const TupleNode* pre, const Expr& post) {
     // The old checked type in the expression may not be valid so clear it
     post->checked_type_ = Type(nullptr);
+    if (pre == root_) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < original_dtype_.size(); i++) {
+        Expr output_element = GetField(post, i);
+        Expr casted_element;
+        auto output_element_type = transform::InferTypeLocal(output_element);
+        casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(output_element);
+      }
+      if (!all_same) return Tuple(new_expr);
+    }
     return post;
   }
 
@@ -436,7 +465,7 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
   ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
       << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
 
-  MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
+  MixedPrecisionPass converter = MixedPrecisionPass(expr, mixed_precision_type);
   auto result = converter.Mutate(expr);
 
   for (auto it = converter.missing_ops_.begin();
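Note: the effect of this patch is that a float32 graph converted to float16 now casts its final expression back to the original output dtype. A minimal sketch of the resulting graph for the single-conv case (shapes and variable names are illustrative assumptions, mirroring the expected module in the unit test added later in this series):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
    weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")

    # After the pass: the conv2d computes in float16, and a trailing cast
    # restores the function's original float32 output dtype.
    out = relay.cast(
        relay.nn.conv2d(
            relay.cast(data, "float16"),
            relay.cast(weight, "float16"),
            strides=(1, 1),
            padding=(1, 1),
            out_dtype="float16",
        ),
        "float32",
    )
    mod = tvm.IRModule.from_expr(out)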
From 5dfd53c6bae8addb5c1143916ec86ef840d2efd5 Mon Sep 17 00:00:00 2001
From: Gayatri Panchapakesan Kumari
Date: Sat, 30 Apr 2022 00:05:11 +0530
Subject: [PATCH 2/6] Add configure to PassContext to fix mixed precision
 output type along with unit test

---
 src/relay/transforms/to_mixed_precision.cc    | 53 ++++++++++++-------
 tests/python/relay/test_to_mixed_precision.py | 39 ++++++++++----
 2 files changed, 64 insertions(+), 28 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 313c3faf3c67..b5385758b326 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -36,6 +36,7 @@
 namespace tvm {
 namespace relay {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.enable_original_type", Bool);
 // A callable which hashes std::pair
 struct pair_hash {
   template <class T1, class T2>
@@ -107,6 +108,7 @@ class MixedPrecisionPass : public MixedModeMutator {
   std::unordered_map<std::string, int> missing_ops_;
   const RelayExprNode* root_;
   std::vector<DataType> original_dtype_;
+  bool enable_original_type_;
 
   Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
     /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -280,18 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
  public:
   using MixedModeMutator::VisitExpr_;
 
-  explicit MixedPrecisionPass(Expr base, DataType mixed_precision_type = DataType::Float(16))
+  explicit MixedPrecisionPass(Expr base, bool enable_original_type,
+                              DataType mixed_precision_type = DataType::Float(16))
       : MixedModeMutator(),
         mixed_precision_type_(mixed_precision_type),
-        root_(Downcast<Function>(base)->body.get()) {
-    if (root_->IsInstance<TupleNode>()) {
-      const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
-      for (Type t : tuple_type->fields) {
-        const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
-        original_dtype_.push_back(tensor_type->dtype);
+        root_(Downcast<Function>(base)->body.get()),
+        enable_original_type_(enable_original_type) {
+    if (enable_original_type_) {
+      if (root_->IsInstance<TupleNode>()) {
+        const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
+        for (Type t : tuple_type->fields) {
+          const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
+          original_dtype_.push_back(tensor_type->dtype);
+        }
+      } else if (root_->IsInstance<CallNode>()) {
+        original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
       }
-    } else if (root_->IsInstance<CallNode>())
-      original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
+    }
     if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
       LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                  << mixed_precision_type_;
@@ -393,9 +400,10 @@ class MixedPrecisionPass : public MixedModeMutator {
     if (accumulation_dtype != output_dtype) {
       output = CastArg(output, GetType(output), output_dtype);
     }
-    if (pre_call_node == static_cast<const CallNode*>(root_)) {
-      if (original_dtype_[0] != output_dtype)
+    if (pre_call_node == static_cast<const CallNode*>(root_) && enable_original_type_) {
+      if (original_dtype_[0] != output_dtype) {
         output = CastArg(output, GetType(output), original_dtype_[0]);
+      }
     }
     return output;
   }
@@ -412,7 +420,7 @@ class MixedPrecisionPass : public MixedModeMutator {
   Expr Rewrite_(const TupleNode* pre, const Expr& post) {
     // The old checked type in the expression may not be valid so clear it
     post->checked_type_ = Type(nullptr);
-    if (pre == root_) {
+    if (pre == root_ && enable_original_type_) {
       Array<Expr> new_expr;
       bool all_same = true;
       for (size_t i = 0; i < original_dtype_.size(); i++) {
@@ -423,7 +431,9 @@ class MixedPrecisionPass : public MixedModeMutator {
         new_expr.push_back(casted_element);
         all_same &= casted_element.same_as(output_element);
       }
-      if (!all_same) return Tuple(new_expr);
+      if (!all_same) {
+        return Tuple(new_expr);
+      }
     }
     return post;
   }
@@ -450,11 +460,12 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   // To access map of ops not registered for error reporting
-  friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
-                               int missing_op_mode);
+  friend Expr ToMixedPrecision(const Expr& expr, bool enable_original_type,
+                               const DataType& mixed_precision_type, int missing_op_mode);
 };
 
-Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
+Expr ToMixedPrecision(const Expr& expr, bool enable_original_type,
+                      const DataType& mixed_precision_type, int missing_op_mode) {
   /*
   missing_op_mode:
@@ -465,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
   ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
      << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
 
-  MixedPrecisionPass converter = MixedPrecisionPass(expr, mixed_precision_type);
+  MixedPrecisionPass converter =
+      MixedPrecisionPass(expr, enable_original_type, mixed_precision_type);
   auto result = converter.Mutate(expr);
 
   for (auto it = converter.missing_ops_.begin();
@@ -489,7 +501,12 @@ namespace transform {
 
 Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(ToMixedPrecision(f, mixed_precision_type, missing_op_mode));
+        bool enable_original_type = false;
+        enable_original_type =
+            pc->GetConfig("relay.ToMixedPrecision.enable_original_type", Bool(enable_original_type))
+                .value();
+        return Downcast<Function>(
+            ToMixedPrecision(f, enable_original_type, mixed_precision_type, missing_op_mode));
       };
   return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
 }
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index 2afd6ff247ab..f242d7ec99ce 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
     mixed_precision_dtype="float16",
     rtol: float = 1e-3,
     atol: float = 0,
+    enable_original_type=False,
 ) -> tvm.runtime.Module:
 
     mod = InferType()(mod)
     result_fp32 = run_module(mod, mod_params)
-    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
-    result_fp16 = run_module(fp16_mod, mod_params)
+
+    if enable_original_type == False:
+        fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+        result_fp16 = run_module(fp16_mod, mod_params)
+    else:
+        with tvm.transform.PassContext(
+            config={"relay.ToMixedPrecision.enable_original_type": True}
+        ):
+            fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+            result_fp16 = run_module(fp16_mod, mod_params)
 
     # Ensure the results are close
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
 
+    if enable_original_type:
+        assert (
+            np.array(result_fp16).dtype == np.array(result_fp32).dtype
+        ), "output type and original type mismatch"
+
     return fp16_mod
@@ -117,16 +131,21 @@ def test_convert_single_conv():
         "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(
+        mod, mod_params, atol=0.01, rtol=1e-3, enable_original_type=True
+    )
 
     expected_mod = tvm.IRModule.from_expr(
-        relay.nn.conv2d(
-            relay.cast(data, "float16"),
-            relay.cast(weight, "float16"),
-            strides=(1, 1),
-            padding=(1, 1),
-            out_dtype="float16",
-        ),
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float16",
+            ),
+            "float32"
+        )
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
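Note: a usage sketch for the PassContext option introduced in this patch (the key is renamed to "relay.ToMixedPrecision.keep_orig_output_dtype" in the next patch). The input module here is an illustrative assumption; InferType must run first because the pass reads the checked type of the function body:

    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
    weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1))
    mod = InferType()(tvm.IRModule.from_expr(conv))

    # Without the config option, the converted module returns float16;
    # with it, outputs are cast back to the original float32 dtype.
    with tvm.transform.PassContext(
        config={"relay.ToMixedPrecision.enable_original_type": True}
    ):
        fp16_mod = ToMixedPrecision("float16")(mod)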
From 14cf90e2c423b47a45638578601b60f1a2b0836b Mon Sep 17 00:00:00 2001
From: Gayatri Panchapakesan Kumari
Date: Mon, 2 May 2022 21:35:38 +0530
Subject: [PATCH 3/6] Rename the Pass config option to keep_orig_output_dtype

---
 src/relay/transforms/to_mixed_precision.cc    | 30 +++++++++----------
 tests/python/relay/test_to_mixed_precision.py | 10 +++----
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index b5385758b326..e1d3a264c222 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -36,7 +36,7 @@
 namespace tvm {
 namespace relay {
 
-TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.enable_original_type", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
 // A callable which hashes std::pair
 struct pair_hash {
   template <class T1, class T2>
@@ -108,7 +108,7 @@ class MixedPrecisionPass : public MixedModeMutator {
   std::unordered_map<std::string, int> missing_ops_;
   const RelayExprNode* root_;
   std::vector<DataType> original_dtype_;
-  bool enable_original_type_;
+  bool keep_orig_output_dtype_;
 
   Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
     /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -282,13 +282,13 @@ class MixedPrecisionPass : public MixedModeMutator {
  public:
   using MixedModeMutator::VisitExpr_;
 
-  explicit MixedPrecisionPass(Expr base, bool enable_original_type,
+  explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
                               DataType mixed_precision_type = DataType::Float(16))
       : MixedModeMutator(),
         mixed_precision_type_(mixed_precision_type),
         root_(Downcast<Function>(base)->body.get()),
-        enable_original_type_(enable_original_type) {
-    if (enable_original_type_) {
+        keep_orig_output_dtype_(keep_orig_output_dtype) {
+    if (keep_orig_output_dtype_) {
       if (root_->IsInstance<TupleNode>()) {
         const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
         for (Type t : tuple_type->fields) {
@@ -400,7 +400,7 @@ class MixedPrecisionPass : public MixedModeMutator {
     if (accumulation_dtype != output_dtype) {
       output = CastArg(output, GetType(output), output_dtype);
     }
-    if (pre_call_node == static_cast<const CallNode*>(root_) && enable_original_type_) {
+    if (pre_call_node == root_ && keep_orig_output_dtype_) {
       if (original_dtype_[0] != output_dtype) {
         output = CastArg(output, GetType(output), original_dtype_[0]);
       }
@@ -420,7 +420,7 @@ class MixedPrecisionPass : public MixedModeMutator {
   Expr Rewrite_(const TupleNode* pre, const Expr& post) {
     // The old checked type in the expression may not be valid so clear it
     post->checked_type_ = Type(nullptr);
-    if (pre == root_ && enable_original_type_) {
+    if (pre == root_ && keep_orig_output_dtype_) {
       Array<Expr> new_expr;
       bool all_same = true;
       for (size_t i = 0; i < original_dtype_.size(); i++) {
@@ -460,11 +460,11 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   // To access map of ops not registered for error reporting
-  friend Expr ToMixedPrecision(const Expr& expr, bool enable_original_type,
-                               const DataType& mixed_precision_type, int missing_op_mode);
+  friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+                               const DataType& mixed_precision_type, int missing_op_mode);
 };
 
-Expr ToMixedPrecision(const Expr& expr, bool enable_original_type,
+Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
                       const DataType& mixed_precision_type, int missing_op_mode) {
   /*
   missing_op_mode:
@@ -477,7 +477,7 @@ Expr ToMixedPrecision(const Expr& expr, bool enable_original_type,
       << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
 
   MixedPrecisionPass converter =
-      MixedPrecisionPass(expr, enable_original_type, mixed_precision_type);
+      MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
   auto result = converter.Mutate(expr);
 
   for (auto it = converter.missing_ops_.begin();
@@ -501,12 +501,12 @@ namespace transform {
 
 Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        bool enable_original_type = false;
-        enable_original_type =
-            pc->GetConfig("relay.ToMixedPrecision.enable_original_type", Bool(enable_original_type))
-                .value();
+        bool keep_orig_output_dtype = false;
+        keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype",
+                                               Bool(keep_orig_output_dtype))
+                                     .value();
         return Downcast<Function>(
-            ToMixedPrecision(f, enable_original_type, mixed_precision_type, missing_op_mode));
+            ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
       };
   return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
 }
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index f242d7ec99ce..5852b768ff2e 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -41,18 +41,18 @@ def verify_mixed_precision_output_close(
     mixed_precision_dtype="float16",
     rtol: float = 1e-3,
     atol: float = 0,
-    enable_original_type=False,
+    keep_orig_output_dtype=False,
 ) -> tvm.runtime.Module:
 
     mod = InferType()(mod)
     result_fp32 = run_module(mod, mod_params)
 
-    if enable_original_type == False:
+    if not keep_orig_output_dtype:
         fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
         result_fp16 = run_module(fp16_mod, mod_params)
     else:
         with tvm.transform.PassContext(
-            config={"relay.ToMixedPrecision.enable_original_type": True}
+            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
         ):
             fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
             result_fp16 = run_module(fp16_mod, mod_params)
@@ -61,7 +61,7 @@ def verify_mixed_precision_output_close(
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
 
-    if enable_original_type:
+    if keep_orig_output_dtype:
         assert (
             np.array(result_fp16).dtype == np.array(result_fp32).dtype
         ), "output type and original type mismatch"
@@ -132,7 +132,7 @@ def test_convert_single_conv():
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
     fp16_mod = verify_mixed_precision_output_close(
-        mod, mod_params, atol=0.01, rtol=1e-3, enable_original_type=True
+        mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True
     )
 
     expected_mod = tvm.IRModule.from_expr(
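Note: with the final option name, callers can check that the rewritten function still advertises the original output dtype; a small self-contained sketch under the same illustrative assumptions as above:

    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
    weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1))
    mod = InferType()(tvm.IRModule.from_expr(conv))

    with tvm.transform.PassContext(
        config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
    ):
        fp16_mod = ToMixedPrecision("float16")(mod)

    # The body computes in float16, but after re-running type inference the
    # return type of main keeps the original float32 dtype.
    fp16_mod = InferType()(fp16_mod)
    assert fp16_mod["main"].ret_type.dtype == "float32"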
From 05f1ca38f560f08390e3a51926560932809f767c Mon Sep 17 00:00:00 2001
From: Gayatri Panchapakesan Kumari
Date: Mon, 2 May 2022 22:25:18 +0530
Subject: [PATCH 4/6] Fix file reformatted

---
 tests/python/relay/test_to_mixed_precision.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index 5852b768ff2e..026b458bde12 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -144,7 +144,7 @@ def test_convert_single_conv():
                 padding=(1, 1),
                 out_dtype="float16",
             ),
-            "float32"
+            "float32",
         )
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)

From f2da139526b548d3a36cc3fb27b53a49dc476624 Mon Sep 17 00:00:00 2001
From: Gayatri Panchapakesan Kumari
Date: Wed, 4 May 2022 07:08:31 +0530
Subject: [PATCH 5/6] Trigger Build

From aab033f9f475615539b67be2762035a3cb9be267 Mon Sep 17 00:00:00 2001
From: Gayatri Panchapakesan Kumari
Date: Wed, 4 May 2022 11:37:14 +0530
Subject: [PATCH 6/6] Trigger Build2