Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions src/relay/analysis/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type);

if (!resolved.defined()) {
solver_->diag_ctx_.Emit(
Diagnostic::Error(this->span)
<< "The Relay type checker is unable to show the following types match.\n"
<< "In particular "
<< "`" << PrettyPrint(lhs->resolved_type) << "` does not match `"
<< PrettyPrint(rhs->resolved_type) << "`");
solver_->Emit(Diagnostic::Error(this->span)
<< "The Relay type checker is unable to show the following types match.\n"
<< "In particular "
<< "`" << PrettyPrint(lhs->resolved_type) << "` does not match `"
<< PrettyPrint(rhs->resolved_type) << "`");
return lhs->resolved_type;
} else {
TypeNode* top = solver_->GetTypeNode(resolved);
Expand Down Expand Up @@ -233,11 +232,10 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {

tvm::Array<IndexExpr> shape;
if (tt1->shape.size() != tt2->shape.size()) {
this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span)
<< "tensor type `" << PrettyPrint(tt1) << "` has "
<< tt1->shape.size() << " dimensions, while `"
<< PrettyPrint(tt2) << "` has " << tt2->shape.size()
<< " dimensions");
this->solver_->Emit(Diagnostic::Error(this->span)
<< "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size()
<< " dimensions, while `" << PrettyPrint(tt2) << "` has "
<< tt2->shape.size() << " dimensions");
return Type(nullptr);
}

Expand Down Expand Up @@ -266,7 +264,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch)
<< " does not match " << std::get<2>(mismatch) << ".";
}
this->solver_->diag_ctx_.Emit(err);
this->solver_->Emit(err);
return Type(nullptr);
}

Expand Down Expand Up @@ -526,7 +524,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
// constructor
TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx)
: reporter_(make_object<Reporter>(this)),
current_func(current_func),
current_func_(current_func),
diag_ctx_(diag_ctx),
module_(diag_ctx->module) {
ICHECK(module_.defined());
Expand Down Expand Up @@ -618,7 +616,7 @@ bool TypeSolver::Solve() {

rnode->resolved = resolved;
} catch (const CompileError& err) {
this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what());
this->Emit(Diagnostic::Error(rnode->span) << err.what());
rnode->resolved = false;
} catch (const Error& e) {
ICHECK(false) << e.what();
Expand Down
8 changes: 2 additions & 6 deletions src/relay/analysis/type_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class TypeSolver {
* \brief Report a diagnostic.
* \param diag The diagnostic to report.
*/
void EmitDiagnostic(const Diagnostic& diag);
void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }

private:
class OccursChecker;
Expand Down Expand Up @@ -176,13 +176,9 @@ class TypeSolver {
/*! \brief Reporter that reports back to self */
TypeReporter reporter_;
/*! \brief The global representing the current function. */
GlobalVar current_func;

public:
GlobalVar current_func_;
/*! \brief The diagnostic context. */
DiagnosticContext diag_ctx_;

private:
/*! \brief The module. */
IRModule module_;

Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
Type checked_type = solver_->Resolve(it->second.checked_type);

if (checked_type.as<IncompleteTypeNode>() != nullptr) {
this->solver_->diag_ctx_.Emit(
this->solver_->Emit(
Diagnostic::Error(op->span)
<< "The type inference pass was unable to infer a type for this expression.\n"
<< "This usually occurs when an operator call is under constrained in some way,"
Expand Down