From f1294854e216758dc0dc070dc89d243d9549a494 Mon Sep 17 00:00:00 2001 From: Raghav-Chakravarthy Date: Fri, 17 Sep 2021 19:54:25 -0400 Subject: [PATCH 1/7] [Code Style] Changed code to match the tvm code style conventions. [Issue] While reviewing the tvm code, I noticed some naming convention issues in the diag_ctx_ and current_func variables. Variable current_func should be current_func_ because it is a class variable Variable diag_ctx_ should be diag_ctx , because it is a public variable [Solution] Changed the variables to match the tvm code style conventions --- src/relay/analysis/type_solver.cc | 14 +++++++------- src/relay/analysis/type_solver.h | 4 ++-- src/relay/transforms/type_infer.cc | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 22e2e9a71040..091b41f13623 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -61,7 +61,7 @@ class TypeSolver::Reporter : public TypeReporterNode { TVM_DLL Span GetSpan() final { return this->span; } - TVM_DLL DiagnosticContext GetDiagCtx() final { return this->solver_->diag_ctx_; } + TVM_DLL DiagnosticContext GetDiagCtx() final { return this->solver_->diag_ctx; } // TVM_DLL void Emit(Diagnostic diagnostic) final { // return this->solver_-> @@ -131,7 +131,7 @@ class TypeSolver::Unifier : public TypeFunctor { Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); if (!resolved.defined()) { - solver_->diag_ctx_.Emit( + solver_->diag_ctx.Emit( Diagnostic::Error(this->span) << "The Relay type checker is unable to show the following types match.\n" << "In particular " @@ -233,7 +233,7 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span) + this->solver_->diag_ctx.Emit(Diagnostic::Error(this->span) << "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size() << " dimensions, while `" << PrettyPrint(tt2) << "` has " << tt2->shape.size() @@ -266,7 +266,7 @@ class TypeSolver::Unifier : public TypeFunctor { err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch) << "."; } - this->solver_->diag_ctx_.Emit(err); + this->solver_->diag_ctx.Emit(err); return Type(nullptr); } @@ -526,8 +526,8 @@ class TypeSolver::Merger : public TypeFunctor { // constructor TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) : reporter_(make_object(this)), - current_func(current_func), - diag_ctx_(diag_ctx), + current_func_(current_func), + diag_ctx(diag_ctx), module_(diag_ctx->module) { ICHECK(module_.defined()); } @@ -618,7 +618,7 @@ bool TypeSolver::Solve() { rnode->resolved = resolved; } catch (const CompileError& err) { - this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what()); + this->diag_ctx.Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; } catch (const Error& e) { ICHECK(false) << e.what(); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 56cea60ceeda..c2cc248e5823 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -176,11 +176,11 @@ class TypeSolver { /*! \brief Reporter that reports back to self */ TypeReporter reporter_; /*! \brief The global representing the current function. */ - GlobalVar current_func; + GlobalVar current_func_; public: /*! \brief The diagnostic context. */ - DiagnosticContext diag_ctx_; + DiagnosticContext diag_ctx; private: /*! \brief The module. */ diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 6c2371716b16..bf8d082ca2fe 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -665,7 +665,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Type checked_type = solver_->Resolve(it->second.checked_type); if (checked_type.as() != nullptr) { - this->solver_->diag_ctx_.Emit( + this->solver_->diag_ctx.Emit( Diagnostic::Error(op->span) << "The type inference pass was unable to infer a type for this expression.\n" << "This usually occurs when an operator call is under constrained in some way," From 342253def0f18e7d3016a512c2f96bbfc488c716 Mon Sep 17 00:00:00 2001 From: Raghav-Chakravarthy Date: Fri, 8 Oct 2021 16:43:15 -0400 Subject: [PATCH 2/7] addressed comments --- src/relay/analysis/type_solver.cc | 20 ++++++++++---------- src/relay/analysis/type_solver.h | 8 ++------ src/relay/transforms/type_infer.cc | 4 ++-- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 091b41f13623..965b196a4572 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -61,7 +61,7 @@ class TypeSolver::Reporter : public TypeReporterNode { TVM_DLL Span GetSpan() final { return this->span; } - TVM_DLL DiagnosticContext GetDiagCtx() final { return this->solver_->diag_ctx; } + TVM_DLL DiagnosticContext GetDiagCtx() final { return this->solver_->diag_ctx_; } // TVM_DLL void Emit(Diagnostic diagnostic) final { // return this->solver_-> @@ -131,7 +131,7 @@ class TypeSolver::Unifier : public TypeFunctor { Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); if (!resolved.defined()) { - solver_->diag_ctx.Emit( + solver_->Emit( Diagnostic::Error(this->span) << "The Relay type checker is unable to show the following types match.\n" << "In particular " @@ -233,11 +233,11 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->diag_ctx.Emit(Diagnostic::Error(this->span) - << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions"); + this->solver_->Emit(Diagnostic::Error(this->span) + << "tensor type `" << PrettyPrint(tt1) << "` has " + << tt1->shape.size() << " dimensions, while `" + << PrettyPrint(tt2) << "` has " << tt2->shape.size() + << " dimensions"); return Type(nullptr); } @@ -266,7 +266,7 @@ class TypeSolver::Unifier : public TypeFunctor { err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch) << "."; } - this->solver_->diag_ctx.Emit(err); + this->solver_->Emit(err); return Type(nullptr); } @@ -527,7 +527,7 @@ class TypeSolver::Merger : public TypeFunctor { TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) : reporter_(make_object(this)), current_func_(current_func), - diag_ctx(diag_ctx), + diag_ctx_(diag_ctx), module_(diag_ctx->module) { ICHECK(module_.defined()); } @@ -618,7 +618,7 @@ bool TypeSolver::Solve() { rnode->resolved = resolved; } catch (const CompileError& err) { - this->diag_ctx.Emit(Diagnostic::Error(rnode->span) << err.what()); + this->Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; } catch (const Error& e) { ICHECK(false) << e.what(); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index c2cc248e5823..878f91cb6693 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -94,7 +94,7 @@ class TypeSolver { * \brief Report a diagnostic. * \param diag The diagnostic to report. */ - void EmitDiagnostic(const Diagnostic& diag); + void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }; private: class OccursChecker; @@ -177,12 +177,8 @@ class TypeSolver { TypeReporter reporter_; /*! \brief The global representing the current function. */ GlobalVar current_func_; - - public: /*! \brief The diagnostic context. */ - DiagnosticContext diag_ctx; - - private: + DiagnosticContext diag_ctx_; /*! \brief The module. */ IRModule module_; diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index bf8d082ca2fe..726ec6c599dc 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -665,12 +665,12 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Type checked_type = solver_->Resolve(it->second.checked_type); if (checked_type.as() != nullptr) { - this->solver_->diag_ctx.Emit( + /*this->solver_->diag_ctx_.Emit( Diagnostic::Error(op->span) << "The type inference pass was unable to infer a type for this expression.\n" << "This usually occurs when an operator call is under constrained in some way," << " check other reported errors for hints of what may of happened."); - } + */} Expr new_e = post.defined() ? post : ExprMutator::VisitExpr_(op); // new_call and new_var's code is only going to be valid for VarNode/CallNode. From ce7c2ff4eb8307b044227b7e1a43e3eed7c92fd3 Mon Sep 17 00:00:00 2001 From: Raghav-Chakravarthy Date: Fri, 8 Oct 2021 16:52:13 -0400 Subject: [PATCH 3/7] removed debug logic --- src/relay/transforms/type_infer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 726ec6c599dc..336b70a51f24 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -665,12 +665,12 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Type checked_type = solver_->Resolve(it->second.checked_type); if (checked_type.as() != nullptr) { - /*this->solver_->diag_ctx_.Emit( + this->solver_->Emit( Diagnostic::Error(op->span) << "The type inference pass was unable to infer a type for this expression.\n" << "This usually occurs when an operator call is under constrained in some way," << " check other reported errors for hints of what may of happened."); - */} + } Expr new_e = post.defined() ? post : ExprMutator::VisitExpr_(op); // new_call and new_var's code is only going to be valid for VarNode/CallNode. From 678a72bb98873c312fa4bf8abc74ce8965faf76b Mon Sep 17 00:00:00 2001 From: Raghav-Chakravarthy Date: Fri, 8 Oct 2021 16:59:34 -0400 Subject: [PATCH 4/7] fixed plint issue --- src/relay/analysis/type_solver.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 878f91cb6693..3bde1a1e3746 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -94,7 +94,7 @@ class TypeSolver { * \brief Report a diagnostic. * \param diag The diagnostic to report. */ - void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }; + void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); } private: class OccursChecker; From 34f9a5077c658e1cff79e9603165602e527415fa Mon Sep 17 00:00:00 2001 From: Raghav-Chakravarthy Date: Fri, 8 Oct 2021 17:12:53 -0400 Subject: [PATCH 5/7] fixed building issue --- src/relay/analysis/type_solver.cc | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 965b196a4572..eca88cde049a 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -131,12 +131,11 @@ class TypeSolver::Unifier : public TypeFunctor { Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); if (!resolved.defined()) { - solver_->Emit( - Diagnostic::Error(this->span) - << "The Relay type checker is unable to show the following types match.\n" - << "In particular " - << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" - << PrettyPrint(rhs->resolved_type) << "`"); + solver_->Emit(Diagnostic::Error(this->span) + << "The Relay type checker is unable to show the following types match.\n" + << "In particular " + << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" + << PrettyPrint(rhs->resolved_type) << "`"); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -234,10 +233,10 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { this->solver_->Emit(Diagnostic::Error(this->span) - << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions"); + << "tensor type `" << PrettyPrint(tt1) << "` has " + << tt1->shape.size() << " dimensions, while `" + << PrettyPrint(tt2) << "` has " << tt2->shape.size() + << " dimensions"); return Type(nullptr); } From 01ac34ca63e8f539a3e2e3a998ac4d0ea66337d7 Mon Sep 17 00:00:00 2001 From: Raghav-Chakravarthy Date: Fri, 8 Oct 2021 17:28:39 -0400 Subject: [PATCH 6/7] fixed whitespace issue --- src/relay/analysis/type_solver.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index eca88cde049a..d954ea0b3160 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -132,9 +132,9 @@ class TypeSolver::Unifier : public TypeFunctor { if (!resolved.defined()) { solver_->Emit(Diagnostic::Error(this->span) - << "The Relay type checker is unable to show the following types match.\n" - << "In particular " - << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" + << "The Relay type checker is unable to show the following types match.\n" + << "In particular " + << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" << PrettyPrint(rhs->resolved_type) << "`"); return lhs->resolved_type; } else { From facf05567f5f064ff086851bc9a9c4585d6003a7 Mon Sep 17 00:00:00 2001 From: Raghav-Chakravarthy Date: Fri, 15 Oct 2021 15:43:28 -0400 Subject: [PATCH 7/7] fixed linting error in type_solver.cc --- src/relay/analysis/type_solver.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index d954ea0b3160..1421906a3bbb 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -233,10 +233,9 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { this->solver_->Emit(Diagnostic::Error(this->span) - << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions"); + << "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size() + << " dimensions, while `" << PrettyPrint(tt2) << "` has " + << tt2->shape.size() << " dimensions"); return Type(nullptr); }