diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 22e2e9a71040..1421906a3bbb 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -131,12 +131,11 @@ class TypeSolver::Unifier : public TypeFunctor { Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); if (!resolved.defined()) { - solver_->diag_ctx_.Emit( - Diagnostic::Error(this->span) - << "The Relay type checker is unable to show the following types match.\n" - << "In particular " - << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" - << PrettyPrint(rhs->resolved_type) << "`"); + solver_->Emit(Diagnostic::Error(this->span) + << "The Relay type checker is unable to show the following types match.\n" + << "In particular " + << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" + << PrettyPrint(rhs->resolved_type) << "`"); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -233,11 +232,10 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span) - << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions"); + this->solver_->Emit(Diagnostic::Error(this->span) + << "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size() + << " dimensions, while `" << PrettyPrint(tt2) << "` has " + << tt2->shape.size() << " dimensions"); return Type(nullptr); } @@ -266,7 +264,7 @@ class TypeSolver::Unifier : public TypeFunctor { err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch) << "."; } - this->solver_->diag_ctx_.Emit(err); + this->solver_->Emit(err); return Type(nullptr); } @@ -526,7 +524,7 @@ class TypeSolver::Merger : public TypeFunctor { // constructor TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) : reporter_(make_object(this)), - current_func(current_func), + current_func_(current_func), diag_ctx_(diag_ctx), module_(diag_ctx->module) { ICHECK(module_.defined()); @@ -618,7 +616,7 @@ bool TypeSolver::Solve() { rnode->resolved = resolved; } catch (const CompileError& err) { - this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what()); + this->Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; } catch (const Error& e) { ICHECK(false) << e.what(); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 56cea60ceeda..3bde1a1e3746 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -94,7 +94,7 @@ class TypeSolver { * \brief Report a diagnostic. * \param diag The diagnostic to report. */ - void EmitDiagnostic(const Diagnostic& diag); + void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); } private: class OccursChecker; @@ -176,13 +176,9 @@ class TypeSolver { /*! \brief Reporter that reports back to self */ TypeReporter reporter_; /*! \brief The global representing the current function. */ - GlobalVar current_func; - - public: + GlobalVar current_func_; /*! \brief The diagnostic context. */ DiagnosticContext diag_ctx_; - - private: /*! \brief The module. */ IRModule module_; diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 6c2371716b16..336b70a51f24 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -665,7 +665,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Type checked_type = solver_->Resolve(it->second.checked_type); if (checked_type.as() != nullptr) { - this->solver_->diag_ctx_.Emit( + this->solver_->Emit( Diagnostic::Error(op->span) << "The type inference pass was unable to infer a type for this expression.\n" << "This usually occurs when an operator call is under constrained in some way,"