Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/mlir/cxx/mlir/builtins_codegen-priv.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.


auto cxx::Codegen::ExpressionVisitor::codegenBuiltinDispatch(
cxx::CallExpressionAST* ast, cxx::BuiltinFunctionKind kind)
-> std::optional<ExpressionResult> {
Expand Down
2 changes: 1 addition & 1 deletion src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
}

void Codegen::enqueueFunctionBody(FunctionSymbol* symbol) {
auto target = symbol->canonical();
auto target = symbol->isSpecialization() ? symbol : symbol->canonical();
if (auto def = target->definition()) target = def;
if (!target->declaration()) return;
if (!enqueuedFunctions_.insert(target).second) return;
Expand Down
16 changes: 16 additions & 0 deletions src/parser/cxx/ast_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@

#include <cxx/ast_fwd.h>
#include <cxx/binder.h>
#include <cxx/diagnostic.h>
#include <cxx/diagnostics_client.h>
#include <cxx/names_fwd.h>
#include <cxx/token_fwd.h>

#include <functional>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -52,6 +55,11 @@ class [[nodiscard]] ASTRewriter {
static auto ensureCompleteClass(TranslationUnit* unit,
ClassSymbol* classSymbol) -> bool;

static void reportPendingInstantiationErrors(TranslationUnit* unit,
Symbol* primaryTemplate,
Symbol* instantiated,
SourceLocation instantiationLoc);

static auto substituteDefaultTypeId(
TranslationUnit* unit, TypeIdAST* typeId,
const std::vector<TemplateArgument>& templateArguments, int depth,
Expand Down Expand Up @@ -99,6 +107,10 @@ class [[nodiscard]] ASTRewriter {
auto arena() const -> Arena*;
auto binder() -> Binder& { return binder_; }

auto takeBodyErrors() -> std::vector<Diagnostic> {
return std::move(bodyErrors_);
}

auto restrictedToDeclarations() const -> bool;
void setRestrictedToDeclarations(bool restrictedToDeclarations);

Expand Down Expand Up @@ -192,6 +204,9 @@ class [[nodiscard]] ASTRewriter {
private:
auto rewriter() -> ASTRewriter* { return this; }

auto shouldCaptureBodyErrors() const -> bool;
void typeCheckAndCapture(std::function<void()> checkFn);

auto getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol*;

auto getTypeParameterPack(SpecifierAST* ast) -> ParameterPackSymbol*;
Expand All @@ -206,6 +221,7 @@ class [[nodiscard]] ASTRewriter {

TranslationUnit* unit_ = nullptr;
std::vector<TemplateArgument> templateArguments_;
std::vector<Diagnostic> bodyErrors_;
ParameterPackSymbol* parameterPack_ = nullptr;
std::optional<int> elementIndex_;
Binder binder_;
Expand Down
8 changes: 8 additions & 0 deletions src/parser/cxx/ast_rewriter_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cxx/decl_specs.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/type_checker.h>

namespace cxx {

Expand Down Expand Up @@ -427,6 +428,13 @@ auto ASTRewriter::DeclarationVisitor::operator()(
copy->rparenLoc = ast->rparenLoc;
copy->semicolonLoc = ast->semicolonLoc;

if (symbol_cast<FunctionSymbol>(binder()->instantiatingSymbol())) {
auto typeChecker = TypeChecker{translationUnit()};
typeChecker.setScope(binder()->scope());
typeChecker.setReportErrors(rewrite.shouldCaptureBodyErrors());
rewrite.typeCheckAndCapture([&] { typeChecker.check(copy); });
}

return copy;
}

Expand Down
17 changes: 16 additions & 1 deletion src/parser/cxx/ast_rewriter_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
#include <cxx/control.h>
#include <cxx/decl.h>
#include <cxx/decl_specs.h>
#include <cxx/dependent_types.h>
#include <cxx/literals.h>
#include <cxx/memory_layout.h>
#include <cxx/name_lookup.h>
#include <cxx/names.h>
#include <cxx/symbols.h>
Expand Down Expand Up @@ -596,6 +598,10 @@ auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
}

copy->symbol = *symbolPtr;
if (auto var = symbol_cast<VariableSymbol>(*symbolPtr)) {
if (var->type()) copy->type = var->type();
if (var->initializer()) return rewrite.expression(var->initializer());
}

} else if (copy->nestedNameSpecifier && copy->nestedNameSpecifier->symbol) {
binder()->qualifiedLookupIdExpression(copy);
Expand Down Expand Up @@ -1205,13 +1211,14 @@ auto ASTRewriter::ExpressionVisitor::operator()(MemberExpressionAST* ast)
copy->symbol = symbol;
copy->type = symbol->type();

// Propagate value category from the base expression.
if (auto field = symbol_cast<FieldSymbol>(symbol);
field && !field->isStatic()) {
copy->valueCategory = ast->valueCategory;
}
}
}
} else if (!isDependent(translationUnit(), objectType)) {
copy->type = nullptr;
}
}
}
Expand Down Expand Up @@ -1428,6 +1435,10 @@ auto ASTRewriter::ExpressionVisitor::operator()(SizeofExpressionAST* ast)
copy->sizeofLoc = ast->sizeofLoc;
copy->expression = rewrite.expression(ast->expression);

if (copy->expression && copy->expression->type) {
copy->value = control()->memoryLayout()->sizeOf(copy->expression->type);
}

return copy;
}

Expand All @@ -1442,6 +1453,10 @@ auto ASTRewriter::ExpressionVisitor::operator()(SizeofTypeExpressionAST* ast)
copy->typeId = rewrite.typeId(ast->typeId);
copy->rparenLoc = ast->rparenLoc;

if (copy->typeId && copy->typeId->type) {
copy->value = control()->memoryLayout()->sizeOf(copy->typeId->type);
}

return copy;
}

Expand Down
88 changes: 52 additions & 36 deletions src/parser/cxx/ast_rewriter_instantiate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ struct Instantiate {
auto operator()(Symbol*) -> Symbol* { return nullptr; }
};

auto isPrimaryTemplate(const std::vector<TemplateArgument>& templateArguments)
-> bool {
[[nodiscard]] auto isPrimaryTemplate(
const std::vector<TemplateArgument>& templateArguments) -> bool {
int expected = 0;
for (const auto& arg : templateArguments) {
if (!std::holds_alternative<Symbol*>(arg)) return false;
Expand Down Expand Up @@ -238,7 +238,8 @@ auto isPrimaryTemplate(const std::vector<TemplateArgument>& templateArguments)
return true;
}

auto templateParameterCount(TemplateDeclarationAST* templateDecl) -> int {
[[nodiscard]] auto templateParameterCount(TemplateDeclarationAST* templateDecl)
-> int {
if (!templateDecl) return 0;
int count = 0;
for (auto parameter : ListView{templateDecl->templateParameterList}) {
Expand All @@ -248,28 +249,21 @@ auto templateParameterCount(TemplateDeclarationAST* templateDecl) -> int {
return count;
}

auto computeInstantiationClassName(
[[nodiscard]] auto computeInstantiationClassName(
TranslationUnit* unit, Symbol* primaryTemplate,
const std::vector<TemplateArgument>& templateArguments) -> std::string {
if (!primaryTemplate) return "template";
return to_string(unit->control()->getTemplateId(primaryTemplate->name(),
templateArguments));
}

struct CapturingDiagnosticsClient final : DiagnosticsClient {
DiagnosticsClient* parent = nullptr;
std::vector<Diagnostic> diagnostics;

explicit CapturingDiagnosticsClient(DiagnosticsClient* parent)
: parent(parent) {}

void report(const Diagnostic& diagnostic) override {
diagnostics.push_back(diagnostic);
if (parent) parent->report(diagnostic);
}
};
[[nodiscard]] auto instantiationLabel(Symbol* symbol) -> std::string_view {
return symbol_cast<FunctionSymbol>(symbol)
? "function template specialization"
: "template class";
}

auto findMutableSpecialization(Symbol* primary, Symbol* spec)
[[nodiscard]] auto findMutableSpecialization(Symbol* primary, Symbol* spec)
-> TemplateSpecialization* {
if (!primary || !spec) return nullptr;
auto search = [spec](auto sym) -> TemplateSpecialization* {
Expand Down Expand Up @@ -309,6 +303,25 @@ auto ASTRewriter::substituteDefaultTypeId(
return rewriter.typeId(typeId);
}

void ASTRewriter::reportPendingInstantiationErrors(
TranslationUnit* unit, Symbol* primaryTemplate, Symbol* instantiated,
SourceLocation instantiationLoc) {
if (!primaryTemplate || !instantiated || !instantiationLoc) return;
if (auto spec = findMutableSpecialization(primaryTemplate, instantiated)) {
if (!spec->instantiationErrors.empty()) {
for (auto& diag : spec->instantiationErrors)
unit->diagnosticsClient()->report(diag);
spec->instantiationErrors.clear();
auto name =
computeInstantiationClassName(unit, primaryTemplate, spec->arguments);
auto label = instantiationLabel(primaryTemplate);
unit->note(instantiationLoc,
std::format("in instantiation of {} '{}' requested here",
label, name));
}
}
}

auto ASTRewriter::instantiate(TranslationUnit* unit,
List<TemplateArgumentAST*>* templateArgumentList,
Symbol* symbol, SourceLocation instantiationLoc,
Expand Down Expand Up @@ -354,24 +367,16 @@ auto ASTRewriter::instantiate(TranslationUnit* unit,
if (auto cached = visit(GetSpecialization{templateArguments}, symbol)) {
auto cachedClass = symbol_cast<ClassSymbol>(cached);
if (!cachedClass) {
if (!sfinaeContext)
reportPendingInstantiationErrors(unit, symbol, cached,
instantiationLoc);
if (savedDiagClient) (void)unit->changeDiagnosticsClient(savedDiagClient);
return cached;
}
if (cachedClass->declaration()) {
if (!sfinaeContext && instantiationLoc) {
if (auto spec = findMutableSpecialization(symbol, cached)) {
if (!spec->instantiationErrors.empty()) {
for (auto& diag : spec->instantiationErrors)
unit->diagnosticsClient()->report(diag);
auto className =
computeInstantiationClassName(unit, symbol, templateArguments);
unit->note(instantiationLoc,
std::format("in instantiation of template class '{}' "
"requested here",
className));
}
}
}
if (!sfinaeContext)
reportPendingInstantiationErrors(unit, symbol, cached,
instantiationLoc);
if (savedDiagClient) (void)unit->changeDiagnosticsClient(savedDiagClient);
return cached;
}
Expand Down Expand Up @@ -419,6 +424,12 @@ auto ASTRewriter::instantiate(TranslationUnit* unit,
auto instance = visit(Instantiate{rewriter}, symbol);
(void)unit->changeDiagnosticsClient(savedDiagClient);
if (sfinaeClient->hadError) return nullptr;
auto bodyErrors = rewriter.takeBodyErrors();
if (!bodyErrors.empty() && instance) {
if (auto spec = findMutableSpecialization(symbol, instance)) {
spec->instantiationErrors = std::move(bodyErrors);
}
}
return instance;
}

Expand All @@ -429,17 +440,22 @@ auto ASTRewriter::instantiate(TranslationUnit* unit,

(void)unit->changeDiagnosticsClient(capturing.parent);

auto bodyErrors = rewriter.takeBodyErrors();
capturing.diagnostics.insert(capturing.diagnostics.end(),
std::make_move_iterator(bodyErrors.begin()),
std::make_move_iterator(bodyErrors.end()));

if (!capturing.diagnostics.empty()) {
if (auto spec = findMutableSpecialization(symbol, instantiatedSymbol)) {
spec->instantiationErrors = capturing.diagnostics;
spec->instantiationErrors = std::move(capturing.diagnostics);
}
if (instantiationLoc) {
auto className =
auto name =
computeInstantiationClassName(unit, symbol, templateArguments);
auto label = instantiationLabel(symbol);
unit->note(instantiationLoc,
std::format("in instantiation of template class '{}' "
"requested here",
className));
std::format("in instantiation of {} '{}' requested here",
label, name));
}
}

Expand Down
26 changes: 25 additions & 1 deletion src/parser/cxx/ast_rewriter_requires.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,32 @@
// cxx
#include <cxx/ast.h>
#include <cxx/ast_interpreter.h>
#include <cxx/dependent_types.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/type_checker.h>

namespace cxx {

auto ASTRewriter::shouldCaptureBodyErrors() const -> bool {
return symbol_cast<FunctionSymbol>(binder_.instantiatingSymbol()) &&
binder_.reportErrors();
}

void ASTRewriter::typeCheckAndCapture(std::function<void()> checkFn) {
if (shouldCaptureBodyErrors()) {
CapturingDiagnosticsClient capture;
auto saved = unit_->changeDiagnosticsClient(&capture);
checkFn();
(void)unit_->changeDiagnosticsClient(saved);
bodyErrors_.insert(bodyErrors_.end(),
std::make_move_iterator(capture.diagnostics.begin()),
std::make_move_iterator(capture.diagnostics.end()));
} else {
checkFn();
}
}

auto ASTRewriter::checkRequiresClause(
TranslationUnit* unit, Symbol* symbol, RequiresClauseAST* clause,
const std::vector<TemplateArgument>& templateArguments, int depth) -> bool {
Expand All @@ -52,9 +72,13 @@ auto ASTRewriter::checkRequiresClause(
}

void ASTRewriter::check(ExpressionAST* ast) {
if (!ast) return;
if (isDependent(unit_, ast)) return;

auto typeChecker = TypeChecker{unit_};
typeChecker.setScope(binder_.scope());
typeChecker.check(ast);
typeChecker.setReportErrors(shouldCaptureBodyErrors());
typeCheckAndCapture([&] { typeChecker.check(ast); });
}

} // namespace cxx
Loading
Loading