From d43d50fab0d9b8c79e16995e14712a934aed1e34 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Wed, 2 Mar 2022 20:53:31 +0800 Subject: [PATCH 01/11] fix InferType bug --- src/relay/transforms/type_infer.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 7de43eb36882..2561e140b3f6 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -771,7 +771,7 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) { this->diag_ctx.Emit(Diagnostic::Bug(function->span) << "the type checked function is malformed, please report this"); } - + type_map_[var].checked_type = resolved_expr->checked_type_; return resolved_expr; } @@ -940,6 +940,7 @@ Pass InferType() { AddGlobalTypes(mod); std::vector > updates; + auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); for (const auto& it : updated_mod->functions) { // Currently we don't type check TIR. // @@ -957,7 +958,7 @@ Pass InferType() { // TODO(@jroesch): we should be able to move the type inferencer outside // of this function but it seems to be more stateful then I expect. - auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); + // auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); auto updated_func = inferencer.Infer(it.first, func); pass_ctx->diag_ctx.value().Render(); From dcf68c799ac56c42ccb0896cb454f9c2cc1ac15f Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Thu, 3 Mar 2022 12:33:12 +0800 Subject: [PATCH 02/11] fix InferType related bug --- src/relay/transforms/type_infer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 2561e140b3f6..7153f06d7418 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -771,7 +771,7 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) { this->diag_ctx.Emit(Diagnostic::Bug(function->span) << "the type checked function is malformed, please report this"); } - type_map_[var].checked_type = resolved_expr->checked_type_; + type_map_[var].checked_type = resolved_expr->checked_type_; return resolved_expr; } From accdfc3ff917fb55aa8b59721da3883c6afe2121 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 13:08:30 +0800 Subject: [PATCH 03/11] check if uint variable is negative --- src/arith/const_fold.h | 3 ++- src/ir/expr.cc | 1 + src/relay/backend/build_module.cc | 2 +- src/relay/transforms/type_infer.cc | 5 ++--- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 7bc04a184633..d916bafb8874 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -108,6 +108,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ + ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint()))<< "Minuend 's value is 0U and it's dtype is uint, while Subtrahend's dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; @@ -119,7 +120,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - TVM_ARITH_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value * pb->value); if (pa) { diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 399873492f04..f249602541b0 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -61,6 +61,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied."; if (dtype.is_uint()) { + // std::cout << "is_uint value is " << value << std::endl; // haoyang ICHECK_GE(value, 0U); } ObjectPtr node = make_object(); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 89ee61c83f7c..9bfda9a90914 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -425,7 +425,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_->name); executor_codegen_->Init(nullptr, config_->legacy_target_map); - executor_codegen_->Codegen(func_module, func, mod_name); + executor_codegen_->Codegen(func_module, func, mod_name); // haoyang here executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 7153f06d7418..dfcd44264b62 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -771,7 +771,6 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) { this->diag_ctx.Emit(Diagnostic::Bug(function->span) << "the type checked function is malformed, please report this"); } - type_map_[var].checked_type = resolved_expr->checked_type_; return resolved_expr; } @@ -940,7 +939,7 @@ Pass InferType() { AddGlobalTypes(mod); std::vector > updates; - auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); + for (const auto& it : updated_mod->functions) { // Currently we don't type check TIR. // @@ -958,7 +957,7 @@ Pass InferType() { // TODO(@jroesch): we should be able to move the type inferencer outside // of this function but it seems to be more stateful then I expect. - // auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); + auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); auto updated_func = inferencer.Infer(it.first, func); pass_ctx->diag_ctx.value().Render(); From 7e2843f73dda0ce77fd7bae1f3ef34a716e2a607 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 13:11:43 +0800 Subject: [PATCH 04/11] check if uint variable is negative --- src/arith/const_fold.h | 2 +- src/ir/expr.cc | 1 - src/relay/backend/build_module.cc | 2 +- src/relay/transforms/type_infer.cc | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index d916bafb8874..32d30b30836c 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -120,7 +120,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - TVM_ARITH_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value * pb->value); if (pa) { diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f249602541b0..399873492f04 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -61,7 +61,6 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied."; if (dtype.is_uint()) { - // std::cout << "is_uint value is " << value << std::endl; // haoyang ICHECK_GE(value, 0U); } ObjectPtr node = make_object(); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 9bfda9a90914..89ee61c83f7c 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -425,7 +425,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_->name); executor_codegen_->Init(nullptr, config_->legacy_target_map); - executor_codegen_->Codegen(func_module, func, mod_name); // haoyang here + executor_codegen_->Codegen(func_module, func, mod_name); executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index dfcd44264b62..aff98aa4b67a 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -939,7 +939,7 @@ Pass InferType() { AddGlobalTypes(mod); std::vector > updates; - + for (const auto& it : updated_mod->functions) { // Currently we don't type check TIR. // From 0f429c2117637f8d75648f359e1463068ecb624c Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 14:02:58 +0800 Subject: [PATCH 05/11] check if uint variable is negative --- src/relay/transforms/type_infer.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index aff98aa4b67a..1383a9eb8f0b 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -771,6 +771,7 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) { this->diag_ctx.Emit(Diagnostic::Bug(function->span) << "the type checked function is malformed, please report this"); } + return resolved_expr; } From 72379cd3e0fddcd748adef85ac01c6802e9d4c53 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 14:04:05 +0800 Subject: [PATCH 06/11] check if uint variable is negative --- src/relay/transforms/type_infer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 1383a9eb8f0b..7de43eb36882 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -940,7 +940,6 @@ Pass InferType() { AddGlobalTypes(mod); std::vector > updates; - for (const auto& it : updated_mod->functions) { // Currently we don't type check TIR. // From 1a844984b28e68e3a99cd814a29e070b3a3f4aa9 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 14:07:54 +0800 Subject: [PATCH 07/11] check if uint variable is negative --- src/arith/const_fold.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 32d30b30836c..c2abe251ce12 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -108,7 +108,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint()))<< "Minuend 's value is 0U and it's dtype is uint, while Subtrahend's dtype is uint; which will cause a negative uint"; + ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint()))<<\ + "Minuend 's value is 0U and it's dtype is uint, while Subtrahend's \ + dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; From 9711c10ea3a5834a2a57d50cbc537378afa7697e Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 14:15:43 +0800 Subject: [PATCH 08/11] check if uint variable is negative --- src/arith/const_fold.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index c2abe251ce12..305233026acc 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -108,9 +108,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint()))<<\ - "Minuend 's value is 0U and it's dtype is uint, while Subtrahend's \ - dtype is uint; which will cause a negative uint"; + ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint())) << \ + "Checked failed. Minuend 's value is 0U and it's dtype is uint " << \ + "while Subtrahend's dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; From 62aa32b129f0a0d86e6ae32230471c0c948aee25 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 14:40:24 +0800 Subject: [PATCH 09/11] check if uint variable is negative --- src/arith/const_fold.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 305233026acc..33694d6b1d5d 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -108,8 +108,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint())) << \ - "Checked failed. Minuend 's value is 0U and it's dtype is uint " << \ + ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint())) << + "Checked failed. Minuend 's value is 0U and it's dtype is uint " << "while Subtrahend's dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); From dfb927b762298e92937c9c1bcc67991c003f9258 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 14:44:06 +0800 Subject: [PATCH 10/11] check if uint variable is negative --- src/arith/const_fold.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 33694d6b1d5d..7d19e0de3ee7 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -108,8 +108,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint())) << - "Checked failed. Minuend 's value is 0U and it's dtype is uint " << + ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint())) << + "Checked failed. Minuend 's value is 0U and it's dtype is uint " << "while Subtrahend's dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); From 5251121ed4c4a4e5d4904471faa9a9c42eaf0c3e Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Fri, 4 Mar 2022 14:48:38 +0800 Subject: [PATCH 11/11] check if uint variable is negative --- src/arith/const_fold.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 7d19e0de3ee7..0e675c6806a2 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -108,9 +108,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint())) << - "Checked failed. Minuend 's value is 0U and it's dtype is uint " << - "while Subtrahend's dtype is uint; which will cause a negative uint"; + ICHECK(!(pa && pa->dtype.is_uint() && pa->value == 0U && b.dtype().is_uint())) + << "Checked failed. Minuend 's value is 0U and it's dtype is uint " + << "while Subtrahend's dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a;