diff --git a/maldoca/astgen/BUILD b/maldoca/astgen/BUILD new file mode 100644 index 0000000..c2d14c8 --- /dev/null +++ b/maldoca/astgen/BUILD @@ -0,0 +1,169 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + + + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//:__subpackages__", + ], +) + +proto_library( + name = "ast_def_proto", + srcs = ["ast_def.proto"], + deps = [":type_proto"], +) + +cc_proto_library( + name = "ast_def_cc_proto", + deps = ["ast_def_proto"], +) + +cc_library( + name = "ast_def", + srcs = [ + "ast_def.cc", + "type.cc", + ], + hdrs = [ + "ast_def.h", + "type.h", + ], + deps = [ + ":ast_def_cc_proto", + ":symbol", + ":type_cc_proto", + "//maldoca/base:status", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:btree", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/functional:bind_front", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:span", + ], +) + +cc_library( + name = "ast_gen", + srcs = ["ast_gen.cc"], + hdrs = 
["ast_gen.h"], + deps = [ + ":ast_def", + ":ast_def_cc_proto", + ":symbol", + "//maldoca/base:path", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/container:btree", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:str_format", + "@com_google_protobuf//src/google/protobuf/io", + "@com_google_protobuf//src/google/protobuf/io:printer", + ], +) + +cc_binary( + name = "ast_gen_main", + srcs = ["ast_gen_main.cc"], + deps = [ + ":ast_def", + ":ast_def_cc_proto", + ":ast_gen", + "//maldoca/base:filesystem", + "//maldoca/base:path", + "//maldoca/base:status", + "@abseil-cpp//absl/flags:flag", + "@abseil-cpp//absl/strings", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + deps = [ + ":ast_def", + ":ast_def_cc_proto", + ":ast_gen", + ":symbol", + "//maldoca/base:filesystem", + "//maldoca/base/testing:status_matchers", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/strings", + "@com_google_protobuf//src/google/protobuf/io", + "@googletest//:gtest_main", + ], +) + +proto_library( + name = "type_proto", + srcs = ["type.proto"], +) + +cc_proto_library( + name = "type_cc_proto", + deps = ["type_proto"], +) + +cc_test( + name = "type_test", + srcs = ["type_test.cc"], + deps = [ + ":ast_def", + ":ast_def_cc_proto", + ":type_cc_proto", + "//maldoca/base:filesystem", + "//maldoca/base/testing:status_matchers", + "@abseil-cpp//absl/container:flat_hash_map", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "symbol", + srcs = ["symbol.cc"], + hdrs = ["symbol.h"], + deps = [ + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/strings", + ], +) + +cc_test( + name = "symbol_test", + srcs = ["symbol_test.cc"], + deps = [ + ":symbol", + "@googletest//:gtest_main", + ], +) diff --git 
a/maldoca/astgen/OWNERS b/maldoca/astgen/OWNERS new file mode 100644 index 0000000..d163000 --- /dev/null +++ b/maldoca/astgen/OWNERS @@ -0,0 +1,2 @@ +# LLVM owners for API updates. (see go/mlir-sla). +suggest-reviewers-ignore: file://depot/google3/llvm/OWNERS diff --git a/maldoca/astgen/ast_def.cc b/maldoca/astgen/ast_def.cc new file mode 100644 index 0000000..573ed0a --- /dev/null +++ b/maldoca/astgen/ast_def.cc @@ -0,0 +1,571 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/ast_def.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/symbol.h" +#include "maldoca/astgen/type.h" +#include "maldoca/base/status_macros.h" + +namespace maldoca { +namespace { + +// NOTE: DO NOT USE this internal function directly. +// Use TopologicalSortDependencies(node, get_dependencies) instead. +// +// - get_dependencies: A function that returns the dependencies of a node. +// +// - pre_order_visited: Internal state. 
+// See comments within the function. +// +// - sorted_dependencies: Output vector. +// Topologically sorted dependencies are appended here. +// See comments within the function. +void TopologicalSortDependencies( + NodeDef *node, + std::function(NodeDef *)> get_dependencies, + absl::flat_hash_set *pre_order_visited, + std::vector *sorted_dependencies) { + // We run a DFS to perform topological sort. + // + // We maintain two sets: + // - sorted_dependencies: The result vector being constructed. + // - pre_order_visited: nodes in the recursion stack. + // + // Each node in the graph can either: + // 1) Appear in neither. + // 2) Appear in "pre_order_visited"; + // 3) Appear in "sorted_dependencies". + // + // Each node is inserted to "pre_order_visited" pre-order; moved to + // "sorted_dependencies" post-order. If a node is already in + // "sorted_dependencies", skip this node (typical DFS); If a node is already + // in "pre_order_visited", this means the graph has cycle! + std::vector dependencies = get_dependencies(node); + for (NodeDef *dependency : dependencies) { + CHECK(!pre_order_visited->contains(dependency)) << "Graph has cycle!"; + if (absl::c_linear_search(*sorted_dependencies, dependency)) { + continue; + } + + pre_order_visited->insert(dependency); + TopologicalSortDependencies(dependency, get_dependencies, pre_order_visited, + sorted_dependencies); + pre_order_visited->erase(dependency); + sorted_dependencies->push_back(dependency); + } +} + +// Performs a topological sort on all the (transitive) dependencies of `node`. +// +// For example: (A <: B means A depends on B) +// +// Input graph: +// CatDog <: Cat, Dog +// Cat <: Animal +// Dog <: Animal +// +// TopologicalSortDependencies(CatDog): +// Animal, Cat, Dog +// +// Note: We use the original order of dependencies to break tie. For example, +// Cat appears before Dog and this is preserved. 
+std::vector TopologicalSortDependencies( + NodeDef *node, + std::function(NodeDef *)> get_dependencies) { + absl::flat_hash_set pre_order_visited; + std::vector sorted_dependencies; + TopologicalSortDependencies(node, get_dependencies, &pre_order_visited, + &sorted_dependencies); + return sorted_dependencies; +} + +// Gets the dependency nodes of a given type. +// +// In the generated C++ code, these nodes must be defined before the type is +// used. +// +// - nodes: All nodes in the AST. +// - dependencies: Output vector. Dependencies are appended to this vector. +void GetDependencies( + const Type &type, + const absl::flat_hash_map> &nodes, + std::vector *dependencies) { + switch (type.kind()) { + case TypeKind::kBuiltin: + case TypeKind::kEnum: + // No dependencies. + break; + + case TypeKind::kClass: { + const auto &class_type = static_cast(type); + auto it = nodes.find(class_type.name().ToPascalCase()); + CHECK(it != nodes.end()) + << class_type.name().ToPascalCase() << " undefined."; + dependencies->push_back(it->second.get()); + break; + } + + case TypeKind::kList: { + const auto &list_type = static_cast(type); + GetDependencies(list_type.element_type(), nodes, dependencies); + break; + } + + case TypeKind::kVariant: { + const auto &variant_type = static_cast(type); + for (const auto &type : variant_type.types()) { + GetDependencies(*type, nodes, dependencies); + } + break; + } + } +} + +} // namespace + +absl::StatusOr EnumMemberDef::FromEnumMemberDefPb( + const EnumMemberDefPb &member_pb) { + Symbol name{member_pb.name()}; + if (name.ToPascalCase() != member_pb.name()) { + return absl::InvalidArgumentError(absl::StrCat( + "The enum member name '", member_pb.name(), "' is not in PascalCase.")); + } + + return EnumMemberDef{std::move(name), std::move(member_pb.string_value())}; +} + +absl::StatusOr EnumDef::FromEnumDefPb(const EnumDefPb &enum_pb) { + Symbol name{enum_pb.name()}; + if (name.ToPascalCase() != enum_pb.name()) { + return 
absl::InvalidArgumentError(absl::StrCat( + "The enum type name '", enum_pb.name(), "' is not in PascalCase.")); + } + + std::vector members; + for (const EnumMemberDefPb &member_pb : enum_pb.members()) { + MALDOCA_ASSIGN_OR_RETURN(auto member, + EnumMemberDef::FromEnumMemberDefPb(member_pb)); + members.push_back(std::move(member)); + } + + return EnumDef{std::move(name), std::move(members)}; +} + +absl::StatusOr FieldDef::FromFieldDefPb(const FieldDefPb &field_pb, + absl::string_view lang_name) { + FieldDef field; + field.name_ = Symbol(field_pb.name()); + + // Check that the name is in camelCase. + if (field.name().ToCamelCase() != field_pb.name()) { + return absl::InvalidArgumentError( + absl::StrCat("Field '", field_pb.name(), "' is not in camelCase.")); + } + + MALDOCA_ASSIGN_OR_RETURN(field.type_, FromTypePb(field_pb.type(), lang_name)); + + if (field_pb.optionalness() == OPTIONALNESS_UNSPECIFIED) { + return absl::InvalidArgumentError( + absl::StrCat("Field '", field_pb.name(), + "' has OPTIONALNESS_UNSPECIFIED. This should be a bug, as " + "the default value is already OPTIONALNESS_REQUIRED.")); + } + field.optionalness_ = field_pb.optionalness(); + + field.kind_ = field_pb.kind(); + field.ignore_in_ir_ = field_pb.ignore_in_ir(); + field.enclose_in_region_ = field_pb.enclose_in_region(); + + return field; +} + +std::optional NodeDef::ir_op_name(absl::string_view lang_name, + FieldKind kind) const { + // If there's a custom IR op name, return it. + if (ir_op_name_.has_value()) { + return Symbol(*ir_op_name_); + } + + // If any descendent has a custom IR op name, then we fallback to mlir::Value. + if (absl::c_any_of(descendants(), [](const NodeDef *descendent) { + return descendent->ir_op_name_.has_value(); + })) { + return std::nullopt; + } + + auto ir_name = absl::StrCat(lang_name, has_control_flow() ? 
"hir" : "ir"); + + Symbol result{ir_name}; + + result += name(); + + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Invalid FieldKind."; + case FIELD_KIND_ATTR: + result += "Attr"; + break; + case FIELD_KIND_RVAL: + case FIELD_KIND_STMT: + result += "Op"; + break; + case FIELD_KIND_LVAL: + result += "RefOp"; + break; + } + + if (!children().empty()) { + result += "Interface"; + } + + return result; +} + +// Returns the op mnemonic for `kind`: the node name for RVAL and STMT, the +// node name followed by "ref" for LVAL. +// +// Returns nullopt when this node or any of its descendants has a custom IR op +// name. Terminates the process (LOG(FATAL)) for FIELD_KIND_UNSPECIFIED, +// FIELD_KIND_ATTR, and any value outside the declared FieldKind enumerators. +std::optional NodeDef::ir_op_mnemonic(FieldKind kind) const { + // If there's a custom IR op name, give up (we won't need mnemonic since we + // won't generate an op). + if (ir_op_name_.has_value()) { + return std::nullopt; + } + + // If any descendent has a custom IR op name, then we fallback to mlir::Value. + if (absl::c_any_of(descendants(), [](const NodeDef *descendent) { + return descendent->ir_op_name_.has_value(); + })) { + return std::nullopt; + } + + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Invalid FieldKind."; + case FIELD_KIND_ATTR: + LOG(FATAL) << "Unsupported FieldKind: " << kind; + case FIELD_KIND_LVAL: + return Symbol(name()) + "ref"; + case FIELD_KIND_RVAL: + case FIELD_KIND_STMT: + return Symbol(name()); + } + + // Only reachable when `kind` holds a value outside the declared enumerators + // (e.g. produced by a cast). Fail loudly instead of flowing off the end of a + // non-void function, which is undefined behavior. 
+ LOG(FATAL) << "Unknown FieldKind: " << kind; +} + +/*static*/ +absl::StatusOr AstDef::FromProto(const AstDefPb &pb) { + std::vector enum_defs; + for (const EnumDefPb &enum_def_pb : pb.enums()) { + MALDOCA_ASSIGN_OR_RETURN(EnumDef enum_def, + EnumDef::FromEnumDefPb(enum_def_pb)); + enum_defs.push_back(std::move(enum_def)); + } + + std::vector node_names; + absl::flat_hash_map> nodes; + + for (const NodeDefPb &node_pb : pb.nodes()) { + if (nodes.contains(node_pb.name())) { + return absl::InvalidArgumentError( + absl::StrCat(node_pb.name(), " already exists!")); + } + + auto node = absl::WrapUnique(new NodeDef()); + + node->name_ = node_pb.name(); + + if (node_pb.has_type()) { + node->type_ = node_pb.type(); + } + + for (const FieldDefPb &field_pb : node_pb.fields()) { + MALDOCA_ASSIGN_OR_RETURN( + FieldDef field, FieldDef::FromFieldDefPb(field_pb, 
pb.lang_name())); + node->fields_.push_back(std::move(field)); + } + + node->has_control_flow_ = node_pb.has_control_flow(); + + if (node_pb.has_ir_op_name()) { + node->ir_op_name_ = node_pb.ir_op_name(); + } + + node->should_generate_ir_op_ = node_pb.should_generate_ir_op(); + + node->has_fold_ = node_pb.has_fold(); + + for (auto kind : node_pb.kinds()) { + node->kinds_.push_back(static_cast(kind)); + } + + for (auto mlir_trait : node_pb.additional_mlir_traits()) { + node->additional_mlir_traits_.push_back( + static_cast(mlir_trait)); + } + + node_names.push_back(node_pb.name()); + nodes.emplace(node_pb.name(), std::move(node)); + } + + // Set parent pointers. + for (const NodeDefPb &node_pb : pb.nodes()) { + NodeDef &node = *nodes.at(node_pb.name()); + + for (absl::string_view parent_name : node_pb.parents()) { + auto it = nodes.find(parent_name); + if (it == nodes.end()) { + return absl::InvalidArgumentError( + absl::StrCat("Parent ", parent_name, " doesn't exist!")); + } + NodeDef *parent = it->second.get(); + node.parents_.push_back(parent); + } + } + + // For union types, create a node to represent each one and add that node as + // a parent of the specified types. 
+ for (const UnionTypePb &union_type_pb : pb.union_types()) { + auto union_type_node = absl::WrapUnique(new NodeDef()); + union_type_node->name_ = union_type_pb.name(); + if (nodes.contains(union_type_pb.name())) { + return absl::InvalidArgumentError( + absl::StrCat(union_type_pb.name(), " already exists!")); + } + node_names.push_back(union_type_pb.name()); + nodes.emplace(union_type_pb.name(), std::move(union_type_node)); + } + + for (const UnionTypePb &union_type_pb : pb.union_types()) { + auto union_type_node = nodes.at(union_type_pb.name()).get(); + for (const std::string &type : union_type_pb.types()) { + auto child_node = nodes.find(type); + if (child_node == nodes.end()) { + return absl::InvalidArgumentError( + absl::StrCat("Union type ", union_type_pb.name(), ": member ", type, + " doesn't exist!")); + } + child_node->second->parents_.push_back(union_type_node); + } + + for (absl::string_view parent_name : union_type_pb.parents()) { + auto it = nodes.find(parent_name); + if (it == nodes.end()) { + return absl::InvalidArgumentError( + absl::StrCat("Parent ", parent_name, " doesn't exist!")); + } + NodeDef *parent = it->second.get(); + union_type_node->parents_.push_back(parent); + } + } + + // NOTE: In the code below, we traverse `node_names` instead of `nodes`. + // `node_names` preserves the original order of definitions. + // This makes sure that the algorithm is always deterministic. + + // Set ancestors vector. + for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + node.ancestors_ = TopologicalSortDependencies( + &node, [](NodeDef *node) { return node->parents_; }); + } + + // Set aggregated_fields vector. 
+ for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + for (NodeDef *ancestor : node.ancestors_) { + for (FieldDef &field : ancestor->fields_) { + node.aggregated_fields_.push_back(&field); + } + } + for (FieldDef &field : node.fields_) { + node.aggregated_fields_.push_back(&field); + } + } + + // Set children vector. + for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + for (NodeDef *parent : node.parents_) { + parent->children_.push_back(&node); + } + } + + // Set descendants vector. + for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + node.descendants_ = TopologicalSortDependencies( + &node, [](NodeDef *node) { return node->children_; }); + } + + // Set leafs vector. + for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + for (NodeDef *descendent : node.descendants_) { + if (!descendent->children().empty()) { + continue; + } + node.leafs_.push_back(descendent); + } + } + + // Set aggregated_kinds vector. + for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + absl::btree_set aggregated_kinds; + + for (NodeDef *ancestor : node.ancestors_) { + for (FieldKind kind : ancestor->kinds_) { + aggregated_kinds.insert(kind); + } + } + for (FieldKind kind : node.kinds_) { + aggregated_kinds.insert(kind); + } + + node.aggregated_kinds_ = {aggregated_kinds.begin(), aggregated_kinds.end()}; + } + + // Set the aggregated_additional_mlir_traits vector. 
+ for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + absl::btree_set aggregated_additional_mlir_traits; + + for (NodeDef *ancestor : node.ancestors_) { + for (MlirTrait trait : ancestor->additional_mlir_traits()) { + aggregated_additional_mlir_traits.insert(trait); + } + } + for (MlirTrait trait : node.additional_mlir_traits()) { + aggregated_additional_mlir_traits.insert(trait); + } + + node.aggregated_additional_mlir_traits_ = { + aggregated_additional_mlir_traits.begin(), + aggregated_additional_mlir_traits.end(), + }; + } + + // Reorder the node definitions so that dependencies always come first. + std::vector topological_sorted_nodes; + absl::flat_hash_set preorder_visited_nodes; + for (const std::string &name : node_names) { + NodeDef &node = *nodes.at(name); + + TopologicalSortDependencies( + &node, + [&nodes](NodeDef *node) { + std::vector dependencies; + dependencies.insert(dependencies.end(), node->parents_.begin(), + node->parents_.end()); + for (const FieldDef &field : node->fields()) { + GetDependencies(field.type(), nodes, &dependencies); + } + return dependencies; + }, + &preorder_visited_nodes, &topological_sorted_nodes); + if (!absl::c_linear_search(topological_sorted_nodes, &node)) { + topological_sorted_nodes.push_back(&node); + } + } + + // For each root node, add an enum field to represent the leaf type. + for (NodeDef *node : topological_sorted_nodes) { + if (!node->parents().empty()) { + continue; + } + if (node->children().empty()) { + continue; + } + + std::vector type_enum_members; + for (const NodeDef *leaf : node->leafs()) { + EnumMemberDef member{Symbol(leaf->name()), leaf->name()}; + type_enum_members.push_back(std::move(member)); + } + + node->node_type_enum_ = EnumDef{ + Symbol{node->name()} + "Type", + std::move(type_enum_members), + }; + } + + // For each ClassType, if it resolves to a NodeDef, store a reference to it. 
+ for (NodeDef *node : topological_sorted_nodes) { + for (FieldDef &field : node->fields_) { + ResolveClassType(field.type(), topological_sorted_nodes); + } + } + + return AstDef{pb.lang_name(), std::move(enum_defs), std::move(node_names), + std::move(nodes), std::move(topological_sorted_nodes)}; +} + +void AstDef::ResolveClassType( + Type &type, absl::Span topological_sorted_nodes) { + switch (type.kind()) { + case TypeKind::kBuiltin: + case TypeKind::kEnum: { + break; + } + case TypeKind::kClass: { + auto &class_type = static_cast(type); + for (const NodeDef *node : topological_sorted_nodes) { + if (node->name() == class_type.name().ToPascalCase()) { + LOG(INFO) << "Resolved class " << node->name(); + class_type.node_def_ = node; + break; + } + } + break; + } + case TypeKind::kList: { + auto &list_type = static_cast(type); + ResolveClassType(list_type.element_type(), topological_sorted_nodes); + break; + } + case TypeKind::kVariant: { + auto &variant_type = static_cast(type); + for (auto &type : variant_type.types()) { + ResolveClassType(*type, topological_sorted_nodes); + } + break; + } + } +} + +} // namespace maldoca diff --git a/maldoca/astgen/ast_def.h b/maldoca/astgen/ast_def.h new file mode 100644 index 0000000..eda8354 --- /dev/null +++ b/maldoca/astgen/ast_def.h @@ -0,0 +1,359 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_AST_DEF_H_ +#define MALDOCA_ASTGEN_AST_DEF_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/symbol.h" +#include "maldoca/astgen/type.h" +#include "maldoca/astgen/type.pb.h" + +namespace maldoca { + +class FieldDef; +class NodeDef; +class AstDef; + +class EnumMemberDef { + public: + explicit EnumMemberDef(Symbol name, std::string string_value) + : name_(std::move(name)), string_value_(std::move(string_value)) {} + + static absl::StatusOr FromEnumMemberDefPb( + const EnumMemberDefPb &member_pb); + + const Symbol &name() const { return name_; } + const std::string &string_value() const { return string_value_; } + + private: + Symbol name_; + std::string string_value_; +}; + +class EnumDef { + public: + explicit EnumDef(Symbol name, std::vector members) + : name_(std::move(name)), members_(std::move(members)) {} + + static absl::StatusOr FromEnumDefPb(const EnumDefPb &enum_pb); + + const Symbol &name() const { return name_; } + absl::Span members() const { return members_; } + + private: + Symbol name_; + std::vector members_; +}; + +// Definition of a field in a class. +class FieldDef { + public: + static absl::StatusOr FromFieldDefPb(const FieldDefPb &field_pb, + absl::string_view lang_name); + + const Symbol &name() const { return name_; } + Optionalness optionalness() const { return optionalness_; } + const Type &type() const { return *type_; } + Type &type() { return *type_; } + FieldKind kind() const { return kind_; } + bool ignore_in_ir() const { return ignore_in_ir_; } + bool enclose_in_region() const { return enclose_in_region_; } + + private: + // Only allows creation from proto. 
+ FieldDef() = default; + + Symbol name_; + Optionalness optionalness_; + std::unique_ptr type_; + FieldKind kind_; + bool ignore_in_ir_; + bool enclose_in_region_; +}; + +// Definition of an AST node type. +// This corresponds to a C++ class. +class NodeDef { + public: + // Class name. + const std::string &name() const { return name_; } + + // Type kind enum. + // + // In the JavaScript object version of the AST, a special "type" string + // represents the kind of the node. + // + // interface BinaryExpression <: Expression { + // type: "BinaryExpression"; <============ This field. + // operator: BinaryOperator; + // left: Expression | PrivateName; + // right: Expression; + // } + // + // The "type" string only has a concrete value in leaf types. + // + // interface Expression <: Node { } <======= No "type" value defined. + // + // The existence of a concrete "type" value suggests that this is a leaf type. + std::optional type() const { return type_; } + + // Fields in the class. + // + // This doesn't include fields in base classes. + absl::Span fields() const { return fields_; } + + // The classes that this derives from. + // + // For example: + // + // interface Identifier <: Expression, Pattern { + // type: "Identifier"; + // name: string; + // } + // + // parents = { Expression, Pattern } + absl::Span parents() const { return parents_; } + + // Topologically sorted: base comes before derived. Use the original + // definition order to break tie. + // + // For example: + // interface Node; + // interface Expression <: Node + // interface Pattern <: Node + // interface Identifier <: Expression, Pattern + // + // ancestors: Node, Expression, Pattern + absl::Span ancestors() const { return ancestors_; } + + // All fields, including those defined by ancestors. + absl::Span aggregated_fields() const { + return aggregated_fields_; + } + + // Direct children of this class. 
+ absl::Span children() const { return children_; } + + // All types that directly or indirectly inherit this class. + absl::Span descendants() const { return descendants_; } + + // All descendants that are leaf classes. + absl::Span leafs() const { return leafs_; } + + // Enum with one member per leaf descendant of this node. + // + // AstDef::FromProto only sets this for nodes that have children but no + // parents; for every other node this returns nullopt. + std::optional node_type_enum() const { + if (node_type_enum_.has_value()) { + return &node_type_enum_.value(); + } else { + return std::nullopt; + } + } + + // Whether an IR op should be automatically generated. + // If false, the op is expected to be manually written. + bool should_generate_ir_op() const { return should_generate_ir_op_; } + + // The allowed FieldKinds for this node. Does not include those specified in + // ancestors. + // + // For the meaning of FieldKind, see comments for the proto definition. + // + // For example: + // + // Expression { + // kinds: FIELD_KIND_RVAL + // } + // Identifier <: Expression { + // kinds: FIELD_KIND_LVAL + // } + // + // For Identifier, kinds() returns [FIELD_KIND_LVAL]. + // + // In practice, you most likely want aggregated_kinds(), which returns + // [FIELD_KIND_RVAL, FIELD_KIND_LVAL]. + absl::Span kinds() const { return kinds_; } + + // The allowed FieldKinds for this node. Includes those specified in + // ancestors. + // + // For the meaning of FieldKind, see comments for the proto definition. + // + // For example: + // + // Expression { + // kinds: FIELD_KIND_RVAL + // } + // Identifier <: Expression { + // kinds: FIELD_KIND_LVAL + // } + // + // For Identifier, aggregated_kinds() returns + // [FIELD_KIND_RVAL, FIELD_KIND_LVAL]. + // + // You may also use kinds(), which returns [FIELD_KIND_LVAL]. However, + // aggregated_kinds() is usually the one you want. 
+ absl::Span aggregated_kinds() const { + return aggregated_kinds_; + } + + // Whether this node has control-flow-related information. + // + // A node is considered to have control-flow-related information if it + // contains some branch semantics. + // + // Example: IfStatement, BreakStatement. 
+ // + // When this is true, we define two ops, one in HIR (high-level IR), one in + // LIR (low-level IR). + bool has_control_flow() const { return has_control_flow_; } + + // The MLIR op name (C++ class name). + // + // Composed of the following parts, in order: + // 1. lang_name followed by "hir" (has_control_flow) or "ir" (otherwise). + // 2. The node name. + // 3. A suffix chosen by `kind`: + // - Attr: "Attr" + // - RVal, Stmt: "Op" + // - LVal: "RefOp" + // 4. "Interface", only for non-leaf types (nodes that have children). + // + // Passing FIELD_KIND_UNSPECIFIED is a fatal error. + // + // If a custom IR op name is specified (NodeDefPb::ir_op_name), returns that + // instead. + // + // If a custom IR op name is specified for any of the descendants, returns + // nullopt. + std::optional ir_op_name(absl::string_view lang_name, + FieldKind kind) const; + + // The stringified MLIR op name (without dialect name). + // + // - Non-leaf type: N/A + // - Leaf type: + // - RVal, Stmt: the node name + // - LVal: the node name followed by "ref" + // + // Passing FIELD_KIND_UNSPECIFIED or FIELD_KIND_ATTR is a fatal error. + // + // If a custom IR op name is specified, returns nullopt. + // + // If a custom IR op name is specified for any of the descendants, returns + // nullopt. + std::optional ir_op_mnemonic(FieldKind kind) const; + + // Copied from NodeDefPb::has_fold by AstDef::FromProto. + bool has_fold() const { return has_fold_; } + + // Additional MLIR traits to add to the op definition in ODS. + absl::Span additional_mlir_traits() const { + return additional_mlir_traits_; + } + + // Additional MLIR traits to add to the op definition in ODS, including those + // from ancestors. + absl::Span aggregated_additional_mlir_traits() const { + return aggregated_additional_mlir_traits_; + } + + private: + // Only AstDef can create NodeDefs. 
+ NodeDef() = default; + + std::string name_; + std::optional type_; + std::vector fields_; + std::vector parents_; + std::vector ancestors_; + std::vector aggregated_fields_; + std::vector children_; + std::vector descendants_; + std::vector leafs_; + std::optional node_type_enum_; + bool should_generate_ir_op_; + std::vector kinds_; + std::vector aggregated_kinds_; + bool has_control_flow_; + std::optional ir_op_name_; + bool has_fold_; + std::vector additional_mlir_traits_; + std::vector aggregated_additional_mlir_traits_; + + friend class AstDef; +}; + +// Definition of an AST. +class AstDef { + public: + // Creates an AST definition from a proto. + // Also checks the validity of the proto. + static absl::StatusOr FromProto(const AstDefPb &pb); + + absl::string_view lang_name() const { return lang_name_; } + + absl::Span enum_defs() const { return enum_defs_; } + + // Names of the nodes in the original order. + absl::Span node_names() const { return node_names_; } + + // Node name => node definition. + const absl::flat_hash_map> &nodes() + const { + return nodes_; + } + + // Nodes listed in topological order. + // This order ensures that dependencies (parent classes, field types) are + // defined before each class. + absl::Span topological_sorted_nodes() const { + return topological_sorted_nodes_; + } + + private: + explicit AstDef( + std::string lang_name, std::vector enum_defs, + std::vector node_names, + absl::flat_hash_map> nodes, + std::vector topological_sorted_nodes) + : lang_name_(std::move(lang_name)), + enum_defs_(std::move(enum_defs)), + node_names_(std::move(node_names)), + nodes_(std::move(nodes)), + topological_sorted_nodes_(std::move(topological_sorted_nodes)) { + LOG(INFO) << "Created AstDef. 
node_names:"; + for (const std::string &node_name : node_names_) { + LOG(INFO) << " " << node_name; + } + } + + std::string lang_name_; + std::vector enum_defs_; + std::vector node_names_; + absl::flat_hash_map> nodes_; + std::vector topological_sorted_nodes_; + + static void ResolveClassType( + Type &type, absl::Span topological_sorted_nodes); +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_AST_DEF_H_ diff --git a/maldoca/astgen/ast_def.proto b/maldoca/astgen/ast_def.proto new file mode 100644 index 0000000..2d169cd --- /dev/null +++ b/maldoca/astgen/ast_def.proto @@ -0,0 +1,352 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package maldoca; + +import "maldoca/astgen/type.proto"; + +option java_multiple_files = true; + + +// ========= +// FieldKind +// ========= +// +// Each field in an AST node has a "kind". This is because different kinds of +// AST nodes lead to different forms of IR ops. +// +// Example: +// +// For the assignment expression "x = y", lhs and rhs have the same AST node +// type: Identifier. +// +// AST: +// ``` +// AssignmentExpression { +// lhs: Identifier {"x"} +// rhs: Identifier {"y"} +// } +// ``` +// +// However, lhs is an lvalue, but rhs is an rvalue. 
Therefore, they lower to +// different IR ops: +// +// ``` +// %lhs = identifier_ref {"x"} +// %rhs = identifier {"y"} +// assignment_expression (%lhs, %rhs) +// ``` +// +// We can see that "x" lowers to an "identifier_ref" op, but "y" lowers to an +// "identifier" op. +// +// In order to support this, we need to: +// +// - Specify that AssignmentExpression::lhs has an "LVAL" kind. +// - Specify that AssignmentExpression::rhs has an "RVAL" kind. +// - Specify that the Identifier node type can be both LVAL and RVAL. +enum FieldKind { + FIELD_KIND_UNSPECIFIED = 0;  // Default value of FieldDefPb::kind. + + // This field is an attribute (has builtin type). + FIELD_KIND_ATTR = 1; + + // This field is an rvalue expression. If the expression type has an LVAL + // kind, we need to create an additional "load" op. + FIELD_KIND_RVAL = 2; + + // This field is an lvalue expression. The expression type must have an LVAL + // kind. + FIELD_KIND_LVAL = 3; + + // This field is a statement. + FIELD_KIND_STMT = 4; +} + +message EnumMemberDefPb {  // One member of an enum definition. + optional string name = 1;  // Name of the member. + optional string string_value = 2;  // String form of the member. +} + +message EnumDefPb {  // An enum definition: a name plus its members. + optional string name = 1; + repeated EnumMemberDefPb members = 2; +} + +// This is an unfortunate Babel-specific detail. Since Babel's AST is a JSON +// object, maybe_null and maybe_undefined are different cases. +// +// - "field: string | null" +// This means the field must exist, but could be null. +// +// Example: The specification for "AwaitExpression" is this: +// +// ``` +// interface AwaitExpression <: Expression { +// type: "AwaitExpression"; +// argument: Expression | null; +// } +// ``` +// +// In the above example, "argument" is maybe_null, which means that in the +// JSON object, the entry must exist, but the value can be null: +// +// ``` +// { +// "type": "AwaitExpression", +// "argument": null // <---- value is null. +// } +// ``` +// +// - "field?: string" +// This means the field might not exist, but if it does, it must be non-null.
+// +// Example: The specification for "CatchClause" is this: +// +// ``` +// interface CatchClause <: Node { +// type: "CatchClause"; +// param?: Pattern; +// body: BlockStatement; +// } +// ``` +// +// In the above example, "param" is maybe_undefined, which means that in the +// JSON object, the entry might not exist at all: +// +// ``` +// { +// "type": "CatchClause", +// // <---- "param" doesn't exist. +// "body": { ... } +// } +// ``` +// +// However, it appears that in Babel's AST, there is no field that is both +// maybe_null and maybe_undefined (in other words, there is no such thing as +// "field?: string | null"). Therefore, both cases are represented as +// "std::optional". +enum Optionalness { + // No semantic meaning. + // go/protodosdonts#do-include-an-unspecified-value-in-an-enum + OPTIONALNESS_UNSPECIFIED = 0; + + // Field must exist and be non-null. + OPTIONALNESS_REQUIRED = 1; + + // Field must exist, but might be null. + OPTIONALNESS_MAYBE_NULL = 2; + + // Field might not exist, but when it exists, it must be non-null. + OPTIONALNESS_MAYBE_UNDEFINED = 3; +} + +// Definition of a field in an AST node. +message FieldDefPb { + // Name of the field. + // + // Must be camelCase. + optional string name = 1; + + // The optionalness of a field - whether it might be null or undefined. + optional Optionalness optionalness = 2 [default = OPTIONALNESS_REQUIRED]; + + // The type of the field. + optional TypePb type = 3; + + // The field kind. E.g. LVAL. + optional FieldKind kind = 4 [default = FIELD_KIND_UNSPECIFIED]; + + // Whether the field should be ignored in the IR op. + // + // For example, we might want to ignore the source location fields, since MLIR + // has builtin support for source location. In the AST <-> IR conversion, we + // need to supply some manually-written code to convert between source + // location fields and MLIR's source location attributes. 
+ optional bool ignore_in_ir = 5 [default = false]; + + // Whether the field should be enclosed in a region. + // + // For example, if the field is a statement, we should create a nested region + // and enclose the field in it. + // + // Detail: + // + // Normally, an argument of an IR op is logically executed before the op + // itself. For example, consider the following AST: + // + // ``` + // BinaryExpression { + // operator: "+" + // left: Identifier { "a" } + // right: Identifier { "b" } + // } + // ``` + // + // We can transform it into the following IR: + // + // ``` + // %left = jsir.identifier {"a"} + // %right = jsir.identifier {"b"} + // %bin_expr = jsir.binary_expression (%left, %right) + // ``` + // + // We can see that the IR is, in a sense, a post-order traversal of the AST. + // + // However, if an AST node has a statement field, we cannot model it in the IR + // this way. For example, consider the following AST: + // + // ``` + // IfStatement { + // test: Identifier { "a" } + // body: SomeSortOfStatement {} + // } + // ``` + // + // If we are to use the same approach, we would transform it into the + // following IR: + // + // ``` + // %test = jsir.identifier {"a"} + // %body = jsir.some_sort_of_statement + // jsir.if_statement (%test, %body) + // ``` + // + // This is problematic, because: + // + // 1. The semantics of the if statement is that `body` might not be executed. + // However, in this IR it appears that body is always executed. + // + // 2. jsir.some_sort_of_statement shouldn't have a return value. + // However, in this IR it must have one, just for the if_statement to + // reference the op. 
+ // + // Therefore, we need to put body **inside** of the if statement, like this: + // + // ``` + // %test = jsir.identifier {"a"} + // jsir.if_statement (%test) { + // jsir.some_sort_of_statement + // } + // ``` + // + // Therefore, we define `enclose_in_region` to specify that this AST field, + // when converted to an IR op, should be further enclosed in a region. + optional bool enclose_in_region = 6 [default = false]; +} + +// MLIR traits. +enum MlirTrait { + MLIR_TRAIT_INVALID = 0; + MLIR_TRAIT_PURE = 1; // Pure + MLIR_TRAIT_ISOLATED_FROM_ABOVE = 2; // IsolatedFromAbove +} + +// Definition of an AST node type. +message NodeDefPb { + // Name of the node. + // + // Must be PascalCase. + // E.g. "BinaryExpression". + optional string name = 1; + + // Type kind string. + // + // In the JavaScript object version of the AST, a special "type" string + // represents the kind of the node. + // + // interface BinaryExpression <: Expression { + // type: "BinaryExpression"; <============ This field. + // operator: BinaryOperator; + // left: Expression | PrivateName; + // right: Expression; + // } + // + // The "type" string only has a concrete value in leaf types. + // + // interface Expression <: Node { } <======= No "type" value defined. + // + // The existence of a concrete "type" value suggests that this is a leaf type. + optional string type = 2; + + // Parent nodes to inherit from. + repeated string parents = 3; + + // Fields defined by this node. + // Not including fields in parents. + repeated FieldDefPb fields = 4; + + // If true, automatically generate the corresponding op. + // If false, the op is expected to be manually written. + optional bool should_generate_ir_op = 5; + + // Supported kinds. Each kind leads to a different IR op. + repeated FieldKind kinds = 6; + + // Whether this op has control flow. If so, we will define a high-level IR op, + // and a low-level IR op. + optional bool has_control_flow = 7; + + // [Optional] Custom MLIR op name. 
+ // + // By default, each AST node corresponds to an equivalent op or interface: + // - JsNumericLiteral <=> JsirNumericLiteralOp + // - JsExpression <=> JsirExpressionOpInterface + // + // If a custom op name is specified, then the corresponding MLIR op will not + // be generated, and its ancestors will not have corresponding interfaces. + // - JsNumericLiteral <=> mlir::arith::ConstantOp + // - JsExpression <=> mlir::Value + optional string ir_op_name = 8; + + // Whether this node has a fold operation. + optional bool has_fold = 9; + + // Additional MLIR traits to add to the op definition. + repeated MlirTrait additional_mlir_traits = 10 [packed = true]; +} + +// Definition of a tagged union node type. +// +// SWC likes to use enum types to model inheritance. This union type simplifies +// the AST definitions by allowing us to mimic this model and implicitly add the +// associated union type as a parent to all the enum members. +message UnionTypePb { + // Name of the union node. + // + // Must be PascalCase. + // E.g. "BinaryExpression". + optional string name = 1; + + // Parent nodes to inherit from. + repeated string parents = 2; + + // Types which are members of this union. + repeated string types = 3; +} + +// Top-level AST definition. +message AstDefPb { + // The shortened language name. + // E.g. "js". + optional string lang_name = 1; + + repeated EnumDefPb enums = 2;  // Enum definitions. + + repeated NodeDefPb nodes = 3;  // AST node type definitions. + + repeated UnionTypePb union_types = 4;  // Tagged union node type definitions. +} diff --git a/maldoca/astgen/ast_gen.cc b/maldoca/astgen/ast_gen.cc new file mode 100644 index 0000000..9f1a792 --- /dev/null +++ b/maldoca/astgen/ast_gen.cc @@ -0,0 +1,4321 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/ast_gen.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/base/path.h" +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "maldoca/astgen/ast_def.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/symbol.h" +#include "maldoca/astgen/type.h" +#include "google/protobuf/io/printer.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace maldoca { +namespace { + +constexpr absl::string_view kOsValueVariableName = "os";  // Value of $os_variable$: the std::ostream parameter name in generated code. +constexpr absl::string_view kJsonValueVariableName = "json";  // Value of $json_variable$: the nlohmann::json parameter name in generated code. + +std::string GetAstHeaderPath(absl::string_view ast_path) {  // Path of the generated AST header: "ast.generated.h" joined onto ast_path. + return JoinPath(ast_path, "ast.generated.h"); +} + +// FieldIs{Argument,Region}: +// +// If a field has ignore_in_ir(), then we don't define anything in the op. +// +// Example: Node::start does not lead to any argument/region in JSIR because we +// want to store the information in mlir::Location.
+// +// If a field has enclose_in_region(), then it's an MLIR "region"; otherwise +// it's an MLIR "argument". +// +// An argument is either an mlir::Attribute or an mlir::Value; +// A region is an mlir::Region. +// +// See FieldDefPb::enclose_in_region for why we need to enclose certain fields +// in a region. +bool FieldIsArgument(const FieldDef *field) { + return !field->ignore_in_ir() && !field->enclose_in_region(); +} + +bool FieldIsRegion(const FieldDef *field) { + return !field->ignore_in_ir() && field->enclose_in_region(); +} + +// Gets the name of the *RegionEndOp (an empty Symbol for statements). +// - For an lval or rval (expression): ExprRegionEndOp. +// - For a list of lvals or rvals (expressions): ExprsRegionEndOp. +Symbol GetRegionEndOp(const AstDef &ast, const FieldDef &field) { + auto ir_name = Symbol(absl::StrCat(ast.lang_name(), "ir"));  // E.g. "jsir" when lang_name() is "js". + + Symbol region_end_op;  // NOTE(review): unused local; candidate for removal. + switch (field.kind()) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + LOG(FATAL) << "Unsupported FieldKind: " << field.kind(); + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: { + if (field.type().IsA()) { + return ir_name + "ExprsRegionEndOp"; + } else { + return ir_name + "ExprRegionEndOp"; + } + } + case FIELD_KIND_STMT: { + return Symbol{}; + } + } +} + +MaybeNull OptionalnessToMaybeNull(Optionalness optionalness) {  // MAYBE_NULL and MAYBE_UNDEFINED map to kYes; UNSPECIFIED and REQUIRED map to kNo. + switch (optionalness) { + case OPTIONALNESS_UNSPECIFIED: + case OPTIONALNESS_REQUIRED: + return MaybeNull::kNo; + case OPTIONALNESS_MAYBE_NULL: + case OPTIONALNESS_MAYBE_UNDEFINED: + return MaybeNull::kYes; + } +} + +struct TabPrinterOptions {  // Optional callbacks for TabPrinter; a null callback is skipped. + std::function print_prefix = nullptr; + std::function print_separator = nullptr; + std::function print_postfix = nullptr; +}; + +class TabPrinter : private TabPrinterOptions {  // Runs print_prefix on the first Print(), print_separator on later calls, and print_postfix on destruction if Print() was ever called. + public: + explicit TabPrinter(TabPrinterOptions options) + : TabPrinterOptions(std::move(options)) {} + + ~TabPrinter() { + if (!is_first_) { + if (print_postfix) { + print_postfix(); + } + } + } + + void Print() { + if (is_first_) {
+ if (print_prefix) { + print_prefix(); + } + is_first_ = false; + } else { + if (print_separator) { + print_separator(); + } + } + } + + private: + bool is_first_ = true; +}; + +// Consistently unindent lines of code so that the outermost line has no +// indentation. Trailing whitespace and leading empty lines are removed. +// +// Example: +// +// Input: +// ``` +// abc +// abc +// abc +// ``` +// +// Output: +// ``` +// abc +// abc +// abc +// ``` +std::string UnIndentedSource(absl::string_view source) { + source = absl::StripTrailingAsciiWhitespace(source); + + std::vector lines = absl::StrSplit(source, '\n'); + + // Remove leading empty lines. + lines.erase(lines.begin(), absl::c_find_if(lines, [](const auto &line) { + return !line.empty(); + })); + + size_t min_indent = absl::c_accumulate( + lines, std::numeric_limits::max(), + [](size_t current_min, const std::string &line) { + size_t first_non_whitespace = line.find_first_not_of(' '); + if (first_non_whitespace == std::string::npos) {  // Blank line: does not affect the minimum. + return current_min; + } + return std::min(current_min, first_non_whitespace); + }); + + for (auto &line : lines) { + if (line.size() >= min_indent) {  // Lines shorter than the common indent are left as-is. + line.erase(0, min_indent); + } + } + + return absl::StrJoin(lines, "\n"); +} + +} // namespace + +// ============================================================================= +// TsInterfacePrinter +// ============================================================================= + +void TsInterfacePrinter::PrintAst(const AstDef &ast) {  // Prints every enum, then every node in node_names() order. + for (const EnumDef &enum_def : ast.enum_defs()) { + PrintEnum(enum_def, ast.lang_name()); + Println(); + } + + for (const auto &name : ast.node_names()) { + const NodeDef &node = *ast.nodes().at(name); + PrintNode(node); + Println(); + } +} + +void TsInterfacePrinter::PrintEnum(const EnumDef &enum_def, + absl::string_view lang_name) {  // NOTE(review): lang_name is not used by this function. + auto vars = WithVars({ + {"EnumName", enum_def.name().ToPascalCase()}, + }); + + Println("type $EnumName$ ="); + { + auto indent = WithIndent(4); + for (const EnumMemberDef &member : enum_def.members()) { + auto
vars = WithVars({ + {"string_value", absl::CEscape(member.string_value())}, + }); + + Println("| \"$string_value$\""); + } + } +} + +void TsInterfacePrinter::PrintNode(const NodeDef &node) { + auto vars = WithVars({ + {"NodeType", node.name()}, + }); + Print("interface $NodeType$"); + + if (!node.parents().empty()) { + Print(" <: "); + + TabPrinter separator_printer{{ + .print_separator = [&] { Print(", "); }, + }}; + for (const NodeDef *parent : node.parents()) { + separator_printer.Print(); + Print(parent->name()); + } + } + + Println(" {"); + { + auto indent = WithIndent(); + for (const FieldDef &field : node.fields()) { + PrintFieldDef(field); + } + } + Println("}"); +} + +void TsInterfacePrinter::PrintFieldDef(const FieldDef &field) { + Print(field.name().ToCamelCase()); + + if (field.optionalness() == OPTIONALNESS_MAYBE_UNDEFINED) { + Print("?"); + } + + Print(": "); + + MaybeNull maybe_null = field.optionalness() == OPTIONALNESS_MAYBE_NULL + ? MaybeNull::kYes + : MaybeNull::kNo; + Print(field.type().JsType(maybe_null)); + + Println(); +} + +std::string PrintTsInterface(const AstDef &ast) { + std::string ts_interface; + { + google::protobuf::io::StringOutputStream os(&ts_interface); + TsInterfacePrinter printer(&os); + printer.PrintAst(ast); + } + return ts_interface; +} + +// ============================================================================= +// CcPrinterBase +// ============================================================================= + +void CcPrinterBase::PrintLicense() { + static const auto *kCcLicenceString = new std::string{UnIndentedSource(R"cc( + // Copyright 2024 Google LLC + // + // Licensed under the Apache License, Version 2.0 (the "License"); + // you may not use this file except in compliance with the License. 
+ // You may obtain a copy of the License at + // + // https://www.apache.org/licenses/LICENSE-2.0 + // + // Unless required by applicable law or agreed to in writing, software + // distributed under the License is distributed on an "AS IS" BASIS, + // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + // See the License for the specific language governing permissions and + // limitations under the License. + )cc")}; + + Println(kCcLicenceString->c_str()); +} + +void CcPrinterBase::PrintEnterNamespace(absl::string_view cc_namespace) {  // Opens one "namespace X {" per "::"-separated piece, outermost first. + for (absl::string_view cc_namespace_piece : + absl::StrSplit(cc_namespace, "::")) { + auto vars = WithVars({ + {"cc_namespace_piece", std::string(cc_namespace_piece)}, + }); + Println("namespace $cc_namespace_piece$ {"); + } +} + +void CcPrinterBase::PrintExitNamespace(absl::string_view cc_namespace) {  // Closes the same pieces in reverse order, innermost first. + std::vector pieces = absl::StrSplit(cc_namespace, "::"); + for (auto it = pieces.rbegin(); it != pieces.rend(); ++it) { + auto vars = WithVars({ + {"cc_namespace_piece", std::string(*it)}, + }); + Println("} // namespace $cc_namespace_piece$"); + } +} + +static std::string ToHeaderGuard(absl::string_view header_path) {  // Upper-cases, maps '/' and '.' to '_', appends '_': "a/b.h" -> "A_B_H_". + std::string header_guard = absl::AsciiStrToUpper(header_path); + absl::StrReplaceAll({{"/", "_"}, {".", "_"}}, &header_guard); + absl::StrAppend(&header_guard, "_"); + return header_guard; +} + +void CcPrinterBase::PrintEnterHeaderGuard(absl::string_view header_path) {  // Emits the #ifndef/#define pair for header_path. + auto vars = WithVars({ + {"HEADER_GUARD", ToHeaderGuard(header_path)}, + }); + + Println("#ifndef $HEADER_GUARD$"); + Println("#define $HEADER_GUARD$"); +} + +void CcPrinterBase::PrintExitHeaderGuard(absl::string_view header_path) {  // Emits the matching #endif with the guard name as a comment. + auto vars = WithVars({ + {"HEADER_GUARD", ToHeaderGuard(header_path)}, + }); + + Println("#endif // $HEADER_GUARD$"); +} + +void CcPrinterBase::PrintIncludeHeader(absl::string_view header_path) { + auto vars = WithVars({ + {"header_path", std::string(header_path)}, + }); + + Println("#include
\"$header_path$\""); +} + +void CcPrinterBase::PrintIncludeHeaders(std::vector header_paths) { + for (absl::string_view header_path : header_paths) { + PrintIncludeHeader(header_path); + } +} + +void CcPrinterBase::PrintTitle(absl::string_view title) { + std::vector commented_lines; + for (absl::string_view line : absl::StrSplit(title, '\n')) { + if (line.empty()) { + commented_lines.push_back("//"); + } else { + commented_lines.push_back(absl::StrCat("// ", line)); + } + } + std::string commented_title = absl::StrJoin(commented_lines, "\n"); + + auto vars = WithVars({ + {"CommentedTitle", commented_title}, + }); + + static const auto *kCode = new std::string(absl::StripAsciiWhitespace(R"( +// ============================================================================= +$CommentedTitle$ +// ============================================================================= + )")); + + Println(kCode->c_str()); +} + +void CcPrinterBase::PrintCodeGenerationWarning() { + PrintTitle("STOP!! DO NOT MODIFY!! 
THIS FILE IS AUTOMATICALLY GENERATED."); +} + +// ============================================================================= +// AstHeaderPrinter +// ============================================================================= + +void AstHeaderPrinter::PrintAst(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path) { + auto header_path = GetAstHeaderPath(ast_path); + + PrintLicense(); + Println(); + + PrintCodeGenerationWarning(); + Println(); + + PrintEnterHeaderGuard(header_path); + Println(); + + Println("// IWYU pragma: begin_keep"); + Println("// NOLINTBEGIN(whitespace/line_length)"); + Println("// clang-format off"); + Println(); + + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println(); + + PrintIncludeHeader("absl/status/statusor.h"); + PrintIncludeHeader("absl/strings/string_view.h"); + PrintIncludeHeader("nlohmann/json.hpp"); + Println(); + + PrintEnterNamespace(cc_namespace); + Println(); + + for (const EnumDef &enum_def : ast.enum_defs()) { + PrintEnum(enum_def, ast.lang_name()); + Println(); + } + + for (const NodeDef *node : ast.topological_sorted_nodes()) { + PrintNode(*node, ast.lang_name()); + Println(); + } + + Println("// clang-format on"); + Println("// NOLINTEND(whitespace/line_length)"); + Println("// IWYU pragma: end_keep"); + Println(); + + PrintExitNamespace(cc_namespace); + Println(); + + PrintExitHeaderGuard(header_path); +} + +void AstHeaderPrinter::PrintEnum(const EnumDef &enum_def, + absl::string_view lang_name) { + auto vars = WithVars({ + {"EnumName", (Symbol(lang_name) + enum_def.name()).ToPascalCase()}, + {"enum_name", enum_def.name().ToSnakeCase()}, + }); + + Println("enum class $EnumName$ {"); + { + auto indent = WithIndent(); + for (const EnumMemberDef &member : enum_def.members()) { + auto vars = WithVars({ + {"kMemberName", (Symbol("k") + member.name()).ToCamelCase()}, + }); + + Println("$kMemberName$,"); + } + } + Println("};"); + Println(); + 
+ Println("absl::string_view $EnumName$ToString($EnumName$ $enum_name$);"); + Println( + "absl::StatusOr<$EnumName$> StringTo$EnumName$(absl::string_view s);"); +} + +void AstHeaderPrinter::PrintNode(const NodeDef &node, + absl::string_view lang_name) { + auto vars = WithVars({ + {"NodeType", (Symbol(lang_name) + node.name()).ToPascalCase()}, + {"json_variable", kJsonValueVariableName}, + {"os_variable", kOsValueVariableName}, + }); + + if (node.node_type_enum().has_value()) { + PrintEnum(*node.node_type_enum().value(), lang_name); + Println(); + } + + Print("class $NodeType$"); + if (!node.parents().empty()) { + Print(" : "); + TabPrinter separator_printer{{ + .print_separator = [&] { Print(", "); }, + }}; + for (const NodeDef *parent : node.parents()) { + auto vars = WithVars({ + {"BaseType", (Symbol(lang_name) + parent->name()).ToPascalCase()}, + }); + + separator_printer.Print(); + Print("public virtual $BaseType$"); + } + } + Println(" {"); + + // Always print "public:" because the declaration of FromJson() always + // exists. + Println(" public:"); + { + auto indent = WithIndent(); + + // Constructor + if (!node.aggregated_fields().empty()) { + PrintConstructor(node, lang_name); + Println(); + } + + // Destructor + if (node.parents().empty() && !node.children().empty()) { + Println("virtual ~$NodeType$() = default;"); + Println(); + } + + // Get type enum. 
+ if (node.node_type_enum().has_value()) { + auto node_type_enum_name = node.node_type_enum().value()->name(); + auto vars = WithVars({ + {"NodeTypeEnum", + (Symbol(lang_name) + node_type_enum_name).ToPascalCase()}, + {"node_type_enum", node_type_enum_name.ToCcVarName()}, + }); + + Println("virtual $NodeTypeEnum$ $node_type_enum$() const = 0;"); + Println(); + + } else if (node.children().empty()) { + for (const NodeDef *ancestor : node.ancestors()) { + if (!ancestor->node_type_enum().has_value()) { + continue; + } + + auto root_type_enum_name = ancestor->node_type_enum().value()->name(); + auto vars = WithVars({ + {"RootTypeEnum", + (Symbol(lang_name) + root_type_enum_name).ToPascalCase()}, + {"root_type_enum", root_type_enum_name.ToCcVarName()}, + {"NodeTypeNoLang", Symbol(node.name()).ToPascalCase()}, + }); + + Println("$RootTypeEnum$ $root_type_enum$() const override {"); + Println(" return $RootTypeEnum$::k$NodeTypeNoLang$;"); + Println("}"); + Println(); + } + } + + // Serialize + if (node.parents().empty()) { + if (node.children().empty()) { + // Non-virtual. + Println("void Serialize(std::ostream& $os_variable$) const;"); + Println(); + } else { + // Virtual base. + // We define a pure virtual function here, and override it in leaf + // types. + Println( + "virtual void Serialize(std::ostream& $os_variable$) " + "const = 0;"); + Println(); + } + } else { + if (node.children().empty()) { + // Leaf type. + // We override the virtual function. + Println( + "void Serialize(std::ostream& $os_variable$) " + "const override;"); + Println(); + } else { + // Non-leaf type - skipped. + // We only override in leaf types. Here it's still pure virtual. + } + } + + // FromJson + Println( + "static absl::StatusOr> FromJson(" + "const nlohmann::json& $json_variable$);"); + Println(); + + // Getters and setters. 
+ for (const FieldDef &field : node.fields()) { + PrintGetterSetterDeclarations(field, lang_name); + Println(); + } + } + + Println(" protected:"); + { + auto indent = WithIndent(); + + // SerializeFields + Println("// Internal function used by Serialize()."); + Println("// Sets the fields defined in this class."); + Println("// Does not set fields defined in ancestors."); + Println( + "void SerializeFields(std::ostream& $os_variable$, " + "bool &needs_comma) const;"); + + // GetFromJson() functions. + if (!node.fields().empty()) { + Println(); + Println("// Internal functions used by FromJson()."); + Println("// Extracts a field from a JSON object."); + for (const FieldDef &field : node.fields()) { + PrintGetFromJson(field, lang_name); + } + } + } + + // Print member variables. + if (!node.fields().empty()) { + Println(); + Println(" private:"); + { + auto indent = WithIndent(); + for (const FieldDef &field : node.fields()) { + PrintMemberVariable(field, lang_name); + } + } + } + + Println("};"); +} + +void AstHeaderPrinter::PrintConstructor(const NodeDef &node, + absl::string_view lang_name) { + auto vars = WithVars({ + {"NodeType", (Symbol(lang_name) + node.name()).ToPascalCase()}, + }); + Print("explicit $NodeType$("); + if (!node.aggregated_fields().empty()) { + Println(); + { + auto indent = WithIndent(4); + TabPrinter separator_printer{{ + .print_separator = [this] { Print(",\n"); }, + }}; + for (const FieldDef *field : node.aggregated_fields()) { + auto vars = WithVars({ + {"cc_type", CcType(*field)}, + {"field_name", field->name().ToCcVarName()}, + }); + + separator_printer.Print(); + Print("$cc_type$ $field_name$"); + } + } + } + Println(");"); +} + +void AstHeaderPrinter::PrintGetterSetterDeclarations( + const FieldDef &field, absl::string_view lang_name) { + std::string cc_getter_type = CcMutableGetterType(field); + std::string cc_const_getter_type = CcConstGetterType(field); + + auto vars = WithVars({ + {"cc_getter_type", cc_getter_type}, + 
{"cc_const_getter_type", cc_const_getter_type}, + {"cc_type", CcType(field)}, + {"field_name", field.name().ToCcVarName()}, + }); + + // If the mutable getter would return the same type as the const getter, skip + // the mutable getter. + if (cc_getter_type != cc_const_getter_type) { + Println("$cc_getter_type$ $field_name$();"); + } + Println("$cc_const_getter_type$ $field_name$() const;"); + Println("void set_$field_name$($cc_type$ $field_name$);"); +} + +void AstHeaderPrinter::PrintMemberVariable(const FieldDef &field, + absl::string_view lang_name) {  // Declares the private data member "<cc_type> <field_name>_;". + auto vars = WithVars({ + {"cc_type", CcType(field)}, + {"field_name", field.name().ToCcVarName()}, + }); + + Println("$cc_type$ $field_name$_;"); +} + +void AstHeaderPrinter::PrintGetFromJson(const FieldDef &field, + absl::string_view lang_name) {  // Declares the static Get<FieldName>() helper that extracts `field` from a JSON object. + auto vars = WithVars({ + {"cc_type", CcType(field)}, + {"FieldName", field.name().ToPascalCase()}, + {"json_variable", kJsonValueVariableName},  // Was "os_variable", which the template below never reads; $json_variable$ previously resolved only through the caller's scope. + }); + + Println( + "static absl::StatusOr<$cc_type$> " + "Get$FieldName$(const nlohmann::json& $json_variable$);"); +} + +std::string PrintAstHeader(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path) {  // Renders the whole generated AST header into a string. + std::string str; + { + google::protobuf::io::StringOutputStream os(&str); + AstHeaderPrinter printer(&os); + printer.PrintAst(ast, cc_namespace, ast_path); + } + + return str; +} + +// ============================================================================= +// AstSourcePrinter +// ============================================================================= + +void AstSourcePrinter::PrintAst(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path) { + auto header_path = GetAstHeaderPath(ast_path); + + PrintLicense(); + Println(); + + PrintCodeGenerationWarning(); + Println(); + + PrintIncludeHeader(header_path); + Println(); + + Println("// IWYU pragma: begin_keep"); + Println("// NOLINTBEGIN(whitespace/line_length)"); + Println("// clang-format off"); + Println(); +
+ Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println(); + + PrintIncludeHeader("absl/container/flat_hash_map.h"); + PrintIncludeHeader("absl/memory/memory.h"); + PrintIncludeHeader("absl/log/log.h"); + PrintIncludeHeader("absl/status/status.h"); + PrintIncludeHeader("absl/status/statusor.h"); + PrintIncludeHeader("absl/strings/str_cat.h"); + PrintIncludeHeader("absl/strings/string_view.h"); + PrintIncludeHeader("nlohmann/json.hpp"); + PrintIncludeHeader("maldoca/base/status_macros.h"); + Println(); + + PrintEnterNamespace(cc_namespace); + Println(); + + for (const EnumDef &enum_def : ast.enum_defs()) { + PrintEnum(enum_def, ast.lang_name()); + Println(); + } + + for (const NodeDef *node : ast.topological_sorted_nodes()) { + PrintNode(*node, ast.lang_name()); + } + + Println("// clang-format on"); + Println("// NOLINTEND(whitespace/line_length)"); + Println("// IWYU pragma: end_keep"); + Println(); + + PrintExitNamespace(cc_namespace); +} + +void AstSourcePrinter::PrintEnum(const EnumDef &enum_def, + absl::string_view lang_name) { + auto vars = WithVars({ + {"EnumName", (Symbol(lang_name) + enum_def.name()).ToPascalCase()}, + {"enum_name", enum_def.name().ToSnakeCase()}, + }); + + Println("absl::string_view $EnumName$ToString($EnumName$ $enum_name$) {"); + { + auto indent = WithIndent(); + Println("switch ($enum_name$) {"); + { + auto indent = WithIndent(); + for (const EnumMemberDef &member : enum_def.members()) { + auto vars = WithVars({ + {"kMemberName", (Symbol("k") + member.name()).ToCamelCase()}, + {"string_value", absl::CEscape(member.string_value())}, + }); + + Println("case $EnumName$::$kMemberName$:"); + Println(" return \"$string_value$\";"); + } + } + Println("}"); + } + Println("}"); + Println(); + + Println( + "absl::StatusOr<$EnumName$> StringTo$EnumName$(absl::string_view s) {"); + { + auto indent = WithIndent(); + + Println( + 
"static const auto *kMap = " + "new absl::flat_hash_map {"); + { + auto indent = WithIndent(4); + for (const EnumMemberDef &member : enum_def.members()) { + auto vars = WithVars({ + {"kMemberName", (Symbol("k") + member.name()).ToCamelCase()}, + {"string_value", absl::CEscape(member.string_value())}, + }); + + Println("{\"$string_value$\", $EnumName$::$kMemberName$},"); + } + } + Println("};"); + Println(); + + const auto code = UnIndentedSource(R"( +auto it = kMap->find(s); +if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for $EnumName$: ", s)); +} +return it->second; + )"); + Println(code); + } + Println("}"); +} + +void AstSourcePrinter::PrintNode(const NodeDef &node, + absl::string_view lang_name) { + PrintTitle((Symbol(lang_name) + node.name()).ToPascalCase()); + Println(); + + auto vars = WithVars({ + {"NodeType", ClassType(Symbol(node.name()), lang_name).CcType()}, + }); + + if (node.node_type_enum().has_value()) { + PrintEnum(*node.node_type_enum().value(), lang_name); + Println(); + } + + if (!node.aggregated_fields().empty()) { + PrintConstructor(node, lang_name); + Println(); + } + + for (const FieldDef &field : node.fields()) { + const Type &type = field.type(); + bool is_optional = field.optionalness() != OPTIONALNESS_REQUIRED; + + std::string cc_getter_type = CcMutableGetterType(field); + std::string cc_const_getter_type = CcConstGetterType(field); + + auto vars = WithVars({ + {"NodeType", (Symbol(lang_name) + node.name()).ToPascalCase()}, + {"cc_getter_type", cc_getter_type}, + {"cc_const_getter_type", cc_const_getter_type}, + {"cc_type", CcType(field)}, + {"field_name", field.name().ToCcVarName()}, + }); + + // If both the mutable getter and const getter would have the same return + // type, then we just skip the mutable getter and only keep the const + // getter. 
+    if (cc_getter_type != cc_const_getter_type) {
+      Println("$cc_getter_type$ $NodeType$::$field_name$() {");
+      {
+        auto indent = WithIndent();
+        PrintGetterBody(field.name(), type, is_optional);
+      }
+      Println("}");
+      Println();
+    }
+
+    Println("$cc_const_getter_type$ $NodeType$::$field_name$() const {");
+    {
+      auto indent = WithIndent();
+      PrintGetterBody(field.name(), type, is_optional);  // same body generator as the mutable getter
+    }
+    Println("}");
+    Println();
+
+    Println("void $NodeType$::set_$field_name$($cc_type$ $field_name$) {");
+    {
+      auto indent = WithIndent();
+      PrintSetterBody(field.name(), type, is_optional);
+    }
+    Println("}");
+    Println();
+  }
+}
+// Emits the out-of-line constructor definition for `node`: parameter list, member-initializer list, then an empty body.
+void AstSourcePrinter::PrintConstructor(const NodeDef &node,
+                                        absl::string_view lang_name) {
+  auto vars = WithVars({
+      {"NodeType", (Symbol(lang_name) + node.name()).ToPascalCase()},
+  });
+  Print("$NodeType$::$NodeType$(");
+  if (!node.aggregated_fields().empty()) {  // one `cc_type field_name` parameter per aggregated field
+    Println();
+    auto indent = WithIndent(4);
+
+    TabPrinter separator_printer{{
+        .print_separator = [this] { Print(",\n"); },
+    }};
+    for (const FieldDef *field : node.aggregated_fields()) {
+      auto vars = WithVars({
+          {"cc_type", CcType(*field)},
+          {"field_name", field->name().ToCcVarName()},
+      });
+
+      separator_printer.Print();  // presumably prints the separator only between items; TabPrinter is defined elsewhere
+      Print("$cc_type$ $field_name$");
+    }
+  }
+  Println(")");
+
+  {  // member-initializer list
+    auto indent = WithIndent(4);
+
+    TabPrinter tab_printer{{
+        .print_prefix =
+            [&] {
+              Print(": ");
+              Indent();
+            },
+        .print_separator = [&] { Print(",\n"); },
+        .print_postfix = [&] { Outdent(); },  // undoes the Indent() done by the prefix callback
+    }};
+    for (const NodeDef *ancestor : node.ancestors()) {  // base-class initializers come first
+      tab_printer.Print();
+
+      auto vars = WithVars({
+          {"AncestorType",
+           (Symbol(lang_name) + ancestor->name()).ToPascalCase()},
+      });
+      Print("$AncestorType$(");
+
+      TabPrinter ancestor_tab_printer{{
+          .print_separator = [&] { Print(", "); },
+      }};
+      for (const FieldDef *field : ancestor->aggregated_fields()) {
+        ancestor_tab_printer.Print();
+
+        auto vars = WithVars({
+            {"field_name", field->name().ToCcVarName()},
+        });
+        Print("std::move($field_name$)");  // each ancestor receives its aggregated fields by move
+      }
+
+      Print(")");
+    }
+
+    for (const FieldDef &field : node.fields()) {  // then the members declared directly on this node
+      auto vars = WithVars({
+          {"field_name", field.name().ToCcVarName()},
+      });
+
+      tab_printer.Print();
+      Print("$field_name$_(std::move($field_name$))");  // member is the parameter name plus a trailing underscore
+    }
+  }
+
+  Println(" {}");  // the generated constructor body is empty
+}
+// Emits the return statement(s) for the expression `cc_expr`, in the form selected by type.kind().
+void AstSourcePrinter::PrintGetterBody(const std::string &cc_expr,
+                                       const Type &type) {
+  auto vars = WithVars({
+      {"cc_expr", cc_expr},
+  });
+
+  switch (type.kind()) {
+    case TypeKind::kBuiltin: {
+      Println("return $cc_expr$;");
+      break;
+    }
+
+    case TypeKind::kEnum: {
+      Println("return $cc_expr$;");
+      break;
+    }
+
+    case TypeKind::kClass: {
+      Println("return $cc_expr$.get();");  // class-typed values are returned through .get()
+      break;
+    }
+
+    case TypeKind::kVariant: {
+      const auto &variant_type = static_cast(type);  // NOTE(review): the cast target type looks stripped from this copy - confirm against the original file
+
+      Println("switch ($cc_expr$.index()) {");
+      {
+        auto indent = WithIndent();
+
+        for (size_t i = 0; i != variant_type.types().size(); ++i) {
+          auto vars = WithVars({
+              {"i", std::to_string(i)},
+          });
+          const ScalarType &type = *variant_type.types().at(i);  // shadows the outer `type` parameter
+
+          Println("case $i$: {");
+          {
+            auto indent = WithIndent();
+            PrintGetterBody(absl::StrFormat("std::get<%zu>(%s)", i, cc_expr),
+                            type);  // recursion emits the return for alternative i
+          }
+          Println("}");
+        }
+
+        Println("default:");
+        Println(" LOG(FATAL) << \"Unreachable code.\";");
+      }
+      Println("}");
+
+      break;
+    }
+
+    case TypeKind::kList: {
+      Println("return &$cc_expr$;");  // lists are returned by address
+      break;
+    }
+  }
+}
+// Field-level overload: for optional fields emits a has_value() check returning std::nullopt, then delegates to the overload above.
+void AstSourcePrinter::PrintGetterBody(const Symbol &field_name,
+                                       const Type &type, bool is_optional) {
+  if (is_optional) {
+    auto vars = WithVars({
+        {"field_name", field_name.ToCcVarName()},
+    });
+
+    Println("if (!$field_name$_.has_value()) {");
+    Println(" return std::nullopt;");
+    Println("} else {");
+    {
+      auto indent = WithIndent();
+      auto value_cc_expr = absl::StrCat(field_name.ToCcVarName(), "_.value()");  // unwrap the optional member
+      PrintGetterBody(value_cc_expr, type);
+    }
+    Println("}");
+
+  } else {
+    PrintGetterBody(absl::StrCat(field_name.ToCcVarName(), "_"), type);  // required field: use the member directly
+  }
+}
+
+void
AstSourcePrinter::PrintSetterBody(const Symbol &field_name, + const Type &type, bool is_optional) { + auto vars = WithVars({ + {"field_name", field_name.ToCcVarName()}, + }); + + if (type.IsA()) { + const auto &builtin_type = static_cast(type); + switch (builtin_type.builtin_kind()) { + case BuiltinTypeKind::kBool: + case BuiltinTypeKind::kDouble: + Println("$field_name$_ = $field_name$;"); + return; + default: + break; + } + } + + if (type.IsA()) { + Println("$field_name$_ = $field_name$;"); + return; + } + + Println("$field_name$_ = std::move($field_name$);"); +} + +std::string PrintAstSource(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path) { + std::string str; + { + google::protobuf::io::StringOutputStream os(&str); + AstSourcePrinter printer(&os); + printer.PrintAst(ast, cc_namespace, ast_path); + } + + return str; +} + +// ============================================================================= +// AstSerializePrinter +// ============================================================================= + +void AstSerializePrinter::PrintAst(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path) { + auto vars = WithVars({ + {"os_variable", kOsValueVariableName}, + }); + + auto header_path = GetAstHeaderPath(ast_path); + + PrintLicense(); + Println(); + + PrintCodeGenerationWarning(); + Println(); + + Println("// IWYU pragma: begin_keep"); + Println("// NOLINTBEGIN(whitespace/line_length)"); + Println("// clang-format off"); + Println(); + + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println(); + + PrintIncludeHeaders({ + std::string(header_path), + "absl/log/log.h", + "absl/memory/memory.h", + "absl/status/status.h", + "absl/strings/string_view.h", + "nlohmann/json.hpp", + "maldoca/base/status_macros.h", + }); + Println(); + + PrintEnterNamespace(cc_namespace); + Println(); + + Println( + R"(void MaybeAddComma(std::ostream 
&$os_variable$, bool &needs_comma) { + if (needs_comma) { + $os_variable$ << ","; + } + needs_comma = true; +} +)"); + + for (const auto &node : ast.topological_sorted_nodes()) { + PrintTitle((Symbol(ast.lang_name()) + node->name()).ToPascalCase()); + Println(); + + PrintSerializeFieldsFunction(*node, ast.lang_name()); + Println(); + + if (node->children().empty()) { + PrintSerializeFunction(*node, ast.lang_name()); + Println(); + } + } + + Println("// clang-format on"); + Println("// NOLINTEND(whitespace/line_length)"); + Println("// IWYU pragma: end_keep"); + Println(); + + PrintExitNamespace(cc_namespace); +} + +void AstSerializePrinter::PrintBuiltinSerialize(const BuiltinType &type, + const std::string &lhs, + const std::string &rhs) { + auto vars = WithVars({ + {"os_variable", kOsValueVariableName}, + {"lhs", lhs}, + {"rhs", rhs}, + }); + + if (!lhs.empty()) { + Println("$os_variable$ << $lhs$ << (nlohmann::json($rhs$)).dump();"); + } else { + Println("$os_variable$ << (nlohmann::json($rhs$)).dump();"); + } +} + +void AstSerializePrinter::PrintEnumSerialize(const EnumType &type, + const std::string &lhs, + const std::string &rhs, + absl::string_view lang_name) { + auto vars = WithVars({ + {"os_variable", kOsValueVariableName}, + {"lhs", lhs}, + {"rhs", rhs}, + {"EnumName", (Symbol(lang_name) + type.name()).ToPascalCase()}, + }); + + if (!lhs.empty()) { + Println( + R"($os_variable$ << $lhs$ << "\"" << $EnumName$ToString($rhs$) << "\"";)"); + } else { + Println(R"($os_variable$ << "\"" << $EnumName$ToString($rhs$) << "\"";)"); + } +} + +void AstSerializePrinter::PrintClassSerialize(const ClassType &type, + const std::string &lhs, + const std::string &rhs) { + auto vars = WithVars({ + {"os_variable", kOsValueVariableName}, + {"lhs", lhs}, + {"rhs", rhs}, + }); + + if (!lhs.empty()) { + Println("$os_variable$ << $lhs$;"); + } + Println("$rhs$->Serialize($os_variable$);"); +} + +void AstSerializePrinter::PrintVariantSerialize(const VariantType &variant_type, + 
const std::string &lhs,
+                                                const std::string &rhs,
+                                                absl::string_view lang_name) {
+  auto vars = WithVars({
+      {"lhs", lhs},
+      {"rhs", rhs},
+  });
+
+  Println("switch ($rhs$.index()) {");  // generated code dispatches on the active variant alternative
+  {
+    auto indent = WithIndent();
+    for (size_t i = 0; i != variant_type.types().size(); ++i) {
+      auto vars = WithVars({
+          {"i", std::to_string(i)},
+      });
+
+      Println("case $i$: {");
+      {
+        auto indent = WithIndent();
+        const ScalarType &type = *variant_type.types()[i];
+        PrintSerialize(type, lhs, absl::StrFormat("std::get<%zu>(%s)", i, rhs),
+                       lang_name);  // serialize alternative i with the same lhs prefix
+        Println("break;");
+      }
+
+      Println("}");
+    }
+
+    Println("default:");
+    Println(" LOG(FATAL) << \"Unreachable code.\";");
+  }
+  Println("}");
+}
+// Emits code that writes `rhs` as a bracketed, comma-separated list, preceded by `lhs` when it is non-empty.
+void AstSerializePrinter::PrintListSerialize(const ListType &list_type,
+                                             const std::string &lhs,
+                                             const std::string &rhs,
+                                             absl::string_view lang_name) {
+  constexpr char kRhsElement[] = "element";  // name of the loop variable in the emitted for-loop
+  CHECK_NE(lhs, kRhsElement);  // the caller expressions must not collide with that name
+  CHECK_NE(rhs, kRhsElement);
+
+  constexpr char kLhsElement[] = "element_json";  // NOTE(review): registered as $lhs_element$ below, but no template in this function uses it
+  CHECK_NE(lhs, kLhsElement);
+  CHECK_NE(rhs, kLhsElement);
+
+  auto vars = WithVars({
+      {"os_variable", kOsValueVariableName},
+      {"lhs", lhs},
+      {"rhs", rhs},
+      {"lhs_element", kLhsElement},
+      {"rhs_element", kRhsElement},
+  });
+
+  if (!lhs.empty()) {
+    Println(R"($os_variable$ << $lhs$ << "[";)");
+  } else {
+    Println(R"($os_variable$ << "[";)");
+  }
+  Println("{");  // presumably a scope block so the emitted needs_comma stays local
+  {
+    auto indent = WithIndent();
+
+    Println("bool needs_comma = false;");
+    Println("for (const auto& $rhs_element$ : $rhs$) {");
+    {
+      auto indent = WithIndent();
+      Println("MaybeAddComma($os_variable$, needs_comma);");
+      PrintNullableToJson(list_type.element_type(),
+                          list_type.element_maybe_null(), "", kRhsElement,
+                          lang_name);  // elements are written with an empty lhs
+    }
+    Println("}");
+  }
+  Println("}");
+  Println(R"($os_variable$ << "]";)");
+}
+// Dispatches on type.kind() to the matching Print*Serialize helper.
+void AstSerializePrinter::PrintSerialize(const Type &type,
+                                         const std::string &lhs,
+                                         const std::string &rhs,
+                                         absl::string_view lang_name) {
+  switch (type.kind()) {
+    case TypeKind::kBuiltin: {
+
const auto &builtin_type = static_cast(type); + PrintBuiltinSerialize(builtin_type, lhs, rhs); + break; + } + + case TypeKind::kEnum: { + const auto &enum_type = static_cast(type); + PrintEnumSerialize(enum_type, lhs, rhs, lang_name); + break; + } + + case TypeKind::kClass: { + const auto &class_type = static_cast(type); + PrintClassSerialize(class_type, lhs, rhs); + break; + } + + case TypeKind::kVariant: { + const auto &variant_type = static_cast(type); + PrintVariantSerialize(variant_type, lhs, rhs, lang_name); + break; + } + + case TypeKind::kList: { + const auto &list_type = static_cast(type); + PrintListSerialize(list_type, lhs, rhs, lang_name); + break; + } + } +} + +void AstSerializePrinter::PrintNullableToJson(const Type &type, + MaybeNull maybe_null, + const std::string &lhs, + const std::string &rhs, + absl::string_view lang_name) { + switch (maybe_null) { + case MaybeNull::kNo: { + PrintSerialize(type, lhs, rhs, lang_name); + break; + } + + case MaybeNull::kYes: { + auto vars = WithVars({ + {"os_variable", kOsValueVariableName}, + {"lhs", lhs}, + {"rhs", rhs}, + }); + + Println("if ($rhs$.has_value()) {"); + { + auto indent = WithIndent(); + auto rhs_value = absl::StrCat(rhs, ".value()"); + PrintSerialize(type, lhs, rhs_value, lang_name); + } + Println("} else {"); + { + auto indent = WithIndent(); + + if (!lhs.empty()) { + Println(R"($os_variable$ << $lhs$ << "null";)"); + } else { + Println(R"($os_variable$ << "null";)"); + } + } + Println("}"); + break; + } + } +} + +void AstSerializePrinter::PrintSerializeFieldsFunction( + const NodeDef &node, absl::string_view lang_name) { + auto vars = WithVars({ + {"NodeType", (Symbol(lang_name) + node.name()).ToPascalCase()}, + {"os_variable", kOsValueVariableName}, + }); + + Println( + "void $NodeType$::SerializeFields(std::ostream& $os_variable$, " + "bool &needs_comma) const {"); + { + auto indent = WithIndent(); + + for (const FieldDef &field : node.fields()) { + // E.g. 
"\"fieldName\":" + auto lhs = absl::StrFormat(R"("\"%s\":")", field.name().ToCamelCase()); + + // E.g. field_name_ + auto rhs = absl::StrCat(field.name().ToCcVarName(), "_"); + + switch (field.optionalness()) { + case OPTIONALNESS_UNSPECIFIED: { + LOG(FATAL) << "Invalid Optionalness. Should be a bug."; + break; + } + + case OPTIONALNESS_REQUIRED: { + Println("MaybeAddComma($os_variable$, needs_comma);"); + PrintSerialize(field.type(), lhs, rhs, lang_name); + break; + } + + case OPTIONALNESS_MAYBE_UNDEFINED: { + auto vars = WithVars({ + {"rhs", rhs}, + }); + + // If == std::nullopt, the assignment does not happen. + Println("if ($rhs$.has_value()) {"); + { + auto indent = WithIndent(); + auto rhs_value = absl::StrCat(rhs, ".value()"); + Println("MaybeAddComma($os_variable$, needs_comma);"); + PrintSerialize(field.type(), lhs, rhs_value, lang_name); + } + Println("}"); + + break; + } + case OPTIONALNESS_MAYBE_NULL: { + Println("MaybeAddComma($os_variable$, needs_comma);"); + PrintNullableToJson(field.type(), MaybeNull::kYes, lhs, rhs, + lang_name); + break; + } + } + } + } + Println("}"); +} + +void AstSerializePrinter::PrintSerializeFunction(const NodeDef &node, + absl::string_view lang_name) { + auto vars = WithVars({ + {"NodeType", (Symbol(lang_name) + node.name()).ToPascalCase()}, + {"NodeTypeNoLangName", node.name()}, + {"os_variable", kOsValueVariableName}, + }); + + Println("void $NodeType$::Serialize(std::ostream& $os_variable$) const {"); + { + auto indent = WithIndent(); + + Println(R"($os_variable$ << "{";)"); + Println("{"); + { + auto indent = WithIndent(); + Println("bool needs_comma = false;"); + + // The "type" field. + if (!node.parents().empty() || !node.children().empty()) { + Println("MaybeAddComma($os_variable$, needs_comma);"); + Println(R"($os_variable$ << "\"type\":\"$NodeTypeNoLangName$\"";)"); + } + + // Assign fields of ancestors of this node. 
+ for (const NodeDef *ancestor : node.ancestors()) { + auto vars = WithVars({ + {"AncestorType", + (Symbol(lang_name) + ancestor->name()).ToPascalCase()}, + }); + Println( + "$AncestorType$::SerializeFields($os_variable$, " + "needs_comma);"); + } + + // Assign fields of the node itself. + Println("$NodeType$::SerializeFields($os_variable$, needs_comma);"); + } + Println("}"); + + Println(R"($os_variable$ << "}";)"); + } + Println("}"); +} + +std::string PrintAstToJson(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path) { + std::string str; + { + google::protobuf::io::StringOutputStream os(&str); + AstSerializePrinter printer(&os); + printer.PrintAst(ast, cc_namespace, ast_path); + } + + return str; +} + +// ============================================================================= +// AstFromJsonPrinter +// ============================================================================= + +// Helper for printing an if-statement. +// +// Usage: +// IfStmtPrinter printer(...); +// printer.PrintCase({ +// [&] { +// PrintConditionHere(); +// }, +// [&] { +// PrintBodyHere(); +// }, +// }); +// printer.PrintCase({ +// [&] { +// PrintAnotherConditionHere(); +// }, +// [&] { +// PrintAnotherBodyHere(); +// }, +// }); +// +// This helper adds the "else" keyword to all subsequent cases. 
+class IfStmtPrinter {
+ public:
+  explicit IfStmtPrinter(google::protobuf::io::Printer *printer)  // the pointer is stored as-is and used for all output
+      : is_first_(true), printer_(printer) {}
+
+  struct IfStmtCase {  // callbacks invoked by PrintCase: condition first, then body
+    std::function condition;  // NOTE(review): the std::function signature looks stripped from this copy - confirm against the original file
+    std::function body;
+  };
+
+  void PrintCase(const IfStmtCase &kase) {  // prints one branch; the closing brace is printed without a trailing newline
+    if (is_first_) {
+      printer_->Print("if (");
+      is_first_ = false;
+    } else {
+      printer_->Print(" else if (");  // continues on the same line as the previous closing brace
+    }
+    kase.condition();
+    printer_->Print(") {\n");
+    {
+      auto indent = printer_->WithIndent();  // body is printed one level deeper
+      kase.body();
+    }
+    printer_->Print("}");
+  }
+
+ private:
+  bool is_first_;  // true until the first PrintCase call
+  google::protobuf::io::Printer *printer_;
+};
+// Adds to `node_names` the PascalCase name of every class type found inside a variant within `type`.
+static void GetCheckedClasses(const Type &type, bool is_part_of_variant,
+                              absl::flat_hash_set *node_names) {
+  switch (type.kind()) {
+    case TypeKind::kBuiltin:
+    case TypeKind::kEnum:
+      break;
+    case TypeKind::kClass: {
+      if (is_part_of_variant) {  // classes outside a variant are not recorded
+        const auto &class_type = static_cast(type);
+        node_names->insert(class_type.name().ToPascalCase());
+      }
+      break;
+    }
+    case TypeKind::kVariant: {
+      const auto &variant_type = static_cast(type);
+      for (const auto &element_type : variant_type.types()) {
+        GetCheckedClasses(*element_type, /*is_part_of_variant=*/true,
+                          node_names);
+      }
+      break;
+    }
+    case TypeKind::kList: {
+      const auto &list_type = static_cast(type);
+      GetCheckedClasses(list_type.element_type(), is_part_of_variant,
+                        node_names);  // a list forwards its own flag to its element type
+      break;
+    }
+  }
+}
+// Runs the per-type overload over every field of every node in `ast` and returns the collected names.
+static absl::flat_hash_set GetCheckedClasses(const AstDef &ast) {
+  absl::flat_hash_set checked_classes;
+  for (const NodeDef *node : ast.topological_sorted_nodes()) {
+    for (const FieldDef &field : node->fields()) {
+      GetCheckedClasses(field.type(), /*is_part_of_variant=*/false,
+                        &checked_classes);
+    }
+  }
+  return checked_classes;
+}
+
+void AstFromJsonPrinter::PrintAst(const AstDef &ast,
+                                  absl::string_view cc_namespace,
+                                  absl::string_view ast_path) {
+  auto vars = WithVars({
+      {"json_variable", kJsonValueVariableName},
+  });
+
+  auto header_path = GetAstHeaderPath(ast_path);
+
+  PrintLicense();
+  Println();
+
+
PrintCodeGenerationWarning(); + Println(); + + Println("// NOLINTBEGIN(whitespace/line_length)"); + Println("// clang-format off"); + Println("// IWYU pragma: begin_keep"); + Println(); + + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println(); + + PrintIncludeHeaders({ + std::string(header_path), + "absl/container/flat_hash_set.h", + "absl/memory/memory.h", + "absl/status/status.h", + "absl/status/statusor.h", + "absl/strings/str_cat.h", + "absl/strings/string_view.h", + "maldoca/base/status_macros.h", + "nlohmann/json.hpp", + }); + Println(); + + PrintEnterNamespace(cc_namespace); + Println(); + + bool needs_get_type_function = absl::c_any_of( + ast.topological_sorted_nodes(), + [](const NodeDef *node) { return !node->children().empty(); }); + if (needs_get_type_function) { + static const auto *kGetType = new std::string(UnIndentedSource(R"( +static absl::StatusOr GetType(const nlohmann::json& $json_variable$) { + auto type_it = $json_variable$.find("type"); + if (type_it == $json_variable$.end()) { + return absl::InvalidArgumentError("`type` is undefined."); + } + const nlohmann::json& json_type = type_it.value(); + if (json_type.is_null()) { + return absl::InvalidArgumentError("json_type is null."); + } + if (!json_type.is_string()) { + return absl::InvalidArgumentError("`json_type` expected to be string."); + } + return json_type.get(); +} + )")); + + Println(kGetType->c_str()); + Println(); + } + + absl::flat_hash_set checked_classes = GetCheckedClasses(ast); + + for (const NodeDef *node : ast.topological_sorted_nodes()) { + PrintTitle((Symbol(ast.lang_name()) + node->name()).ToPascalCase()); + Println(); + + if (checked_classes.contains(node->name())) { + PrintTypeChecker(*node); + Println(); + } + + for (const FieldDef &field : node->fields()) { + PrintGetFieldFunction(node->name(), field, ast.lang_name()); + Println(); + } + + 
PrintFromJsonFunction(*node, ast.lang_name()); + Println(); + } + + Println("// clang-format on"); + Println("// NOLINTEND(whitespace/line_length)"); + Println("// IWYU pragma: end_keep"); + Println(); + + PrintExitNamespace(cc_namespace); +} + +void AstFromJsonPrinter::PrintTypeChecker(const NodeDef &node) { + auto vars = WithVars({ + {"NodeType", std::string(node.name())}, + {"json_variable", kJsonValueVariableName}, + }); + + Println("static bool Is$NodeType$(const nlohmann::json& $json_variable$) {"); + absl::Cleanup end_body = [&] { Println("}"); }; + { + auto indent = WithIndent(); + + Println("if (!$json_variable$.is_object()) {"); + Println(" return false;"); + Println("}"); + + if (node.children().empty() && node.parents().empty()) { + // This is not a virtual class. + Println("return true;"); + return; + } + + const std::string code = UnIndentedSource(R"cc( + auto type_it = $json_variable$.find("type"); + if (type_it == $json_variable$.end()) { + return false; + } + const nlohmann::json &type_json = type_it.value(); + if (!type_json.is_string()) { + return false; + } + const std::string &type = type_json.get(); + )cc"); + Println(code); + + if (!node.leafs().empty()) { + Println( + "static const auto *kTypes = new absl::flat_hash_set{"); + { + auto indent = WithIndent(4); + for (const NodeDef *leaf : node.leafs()) { + auto vars = WithVars({ + {"LeafType", leaf->name()}, + }); + Println("\"$LeafType$\","); + } + } + Println("};"); + Println(); + + Println("return kTypes->contains(type);"); + + } else { + CHECK_EQ(node.name(), node.type().value()); + Println("return type == \"$NodeType$\";"); + } + } +} + +void AstFromJsonPrinter::PrintBuiltinJsonTypeCheck(const BuiltinType &type, + const Symbol &rhs) { + auto vars = WithVars({ + {"rhs", rhs.ToCcVarName()}, + }); + + switch (type.builtin_kind()) { + case BuiltinTypeKind::kBool: + Print("$rhs$.is_boolean()"); + break; + case BuiltinTypeKind::kInt64: + Print("$rhs$.is_number_integer()"); + break; + case 
BuiltinTypeKind::kDouble: + Print("$rhs$.is_number()"); + break; + case BuiltinTypeKind::kString: + Print("$rhs$.is_string()"); + break; + } +} + +void AstFromJsonPrinter::PrintClassJsonTypeCheck(const ClassType &class_type, + const Symbol &rhs) { + auto vars = WithVars({ + {"ClassType", class_type.name().ToPascalCase()}, + {"rhs", rhs.ToCcVarName()}, + }); + + Print("Is$ClassType$($rhs$)"); +} + +void AstFromJsonPrinter::PrintBuiltinFromJson(Action action, + CheckJsonType check_json_type, + const BuiltinType &builtin_type, + const Symbol &lhs, + const Symbol &rhs) { + auto vars = WithVars({ + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs.ToCcVarName()}, + {"cc_type", builtin_type.CcType()}, + {"js_type", builtin_type.JsType()}, + }); + + switch (check_json_type) { + case CheckJsonType::kYes: { + IfStmtPrinter if_stmt(this); + if_stmt.PrintCase({ + [&] { + // Print if-condition. + Print("!"); + PrintBuiltinJsonTypeCheck(builtin_type, rhs); + }, + + [&] { + // Print if-body. + Print("return absl::InvalidArgumentError(\"Expecting "); + PrintBuiltinJsonTypeCheck(builtin_type, rhs); + Println(".\");"); + }, + }); + + Println(); + break; + } + case CheckJsonType::kNo: + break; + } + + switch (action) { + case Action::kAssign: + Println("$lhs$ = $rhs$.get<$cc_type$>();"); + break; + case Action::kDef: + Println("auto $lhs$ = $rhs$.get<$cc_type$>();"); + break; + case Action::kReturn: + Println("return $rhs$.get<$cc_type$>();"); + break; + } +} + +void AstFromJsonPrinter::PrintEnumFromJson(Action action, + const EnumType &enum_type, + const Symbol &lhs, const Symbol &rhs, + absl::string_view lang_name) { + auto vars = WithVars({ + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs.ToCcVarName()}, + {"rhs_str", (rhs + "str").ToCcVarName()}, + {"EnumName", (Symbol(lang_name) + enum_type.name()).ToPascalCase()}, + }); + + const auto check = UnIndentedSource(R"( + if (!$rhs$.is_string()) { + return absl::InvalidArgumentError("`$rhs$` expected to be a string."); + } + std::string $rhs_str$ 
= $rhs$.get(); + )"); + Println(check); + + switch (action) { + case Action::kAssign: + Println( + "MALDOCA_ASSIGN_OR_RETURN($lhs$, StringTo$EnumName$($rhs_str$));"); + break; + case Action::kDef: + Println( + "MALDOCA_ASSIGN_OR_RETURN" + "(auto $lhs$, StringTo$EnumName$($rhs_str$));"); + break; + case Action::kReturn: + Println("return StringTo$EnumName$($rhs_str$);"); + break; + } +} + +void AstFromJsonPrinter::PrintClassFromJson(Action action, + const ClassType &class_type, + const Symbol &lhs, + const Symbol &rhs, + absl::string_view lang_name) { + auto vars = WithVars({ + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs.ToCcVarName()}, + {"Class", (Symbol(lang_name) + class_type.name()).ToPascalCase()}, + }); + + switch (action) { + case Action::kAssign: + Println("MALDOCA_ASSIGN_OR_RETURN($lhs$, $Class$::FromJson($rhs$));"); + break; + case Action::kDef: + Println( + "MALDOCA_ASSIGN_OR_RETURN(auto $lhs$, $Class$::FromJson($rhs$));"); + break; + case Action::kReturn: + Println("return $Class$::FromJson($rhs$);"); + break; + } +} + +void AstFromJsonPrinter::PrintVariantFromJson(Action action, + const VariantType &variant_type, + const Symbol &lhs, + const Symbol &rhs, + absl::string_view lang_name) { + auto vars = WithVars({ + {"cc_type", variant_type.CcType()}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs.ToCcVarName()}, + {"json_variable", kJsonValueVariableName}, + }); + + switch (action) { + case Action::kDef: + Println("$cc_type$ $lhs$;"); + break; + + case Action::kAssign: + case Action::kReturn: + break; + } + + IfStmtPrinter if_stmt_printer(this); + + Action case_action = [&] { + switch (action) { + case Action::kAssign: + case Action::kDef: + return Action::kAssign; + case Action::kReturn: + return Action::kReturn; + } + }(); + + for (const auto &scalar_type : variant_type.types()) { + if (scalar_type->IsA()) { + const auto &builtin_type = static_cast(*scalar_type); + + if_stmt_printer.PrintCase({ + [&] { + // Print if-condition. 
+ PrintBuiltinJsonTypeCheck(builtin_type, rhs); + }, + [&] { + // Print if-body. + PrintBuiltinFromJson(case_action, CheckJsonType::kNo, builtin_type, + lhs, rhs); + }, + }); + + } else if (scalar_type->IsA()) { + const auto &class_type = static_cast(*scalar_type); + + if_stmt_printer.PrintCase({ + [&] { + // Print if-condition. + PrintClassJsonTypeCheck(class_type, rhs); + }, + [&] { + // Print if-body. + PrintClassFromJson(case_action, class_type, lhs, rhs, lang_name); + }, + }); + + } else { + LOG(FATAL) << "Unreachable code."; + } + } + + const auto handle_invalid_type = UnIndentedSource(R"( + else { + auto result = absl::InvalidArgumentError("$rhs$ has invalid type."); + result.SetPayload("json", absl::Cord{$json_variable$.dump()}); + result.SetPayload("json_element", absl::Cord{$rhs$.dump()}); + return result; + } + )"); + Println(handle_invalid_type); +} + +void AstFromJsonPrinter::PrintListFromJson(Action action, + const ListType &list_type, + const Symbol &lhs, const Symbol &rhs, + absl::string_view lang_name) { + const Symbol lhs_element = lhs + "element"; + const Symbol rhs_element = rhs + "element"; + + // Even if we are asked to assign to `lhs`, since the type of `lhs` is not + // exactly list_type.CcType(), we need to define a new variable and assign it + // to `lhs` at the end. 
+ const Symbol lhs_defined = [&] { + switch (action) { + case Action::kDef: + return lhs; + case Action::kAssign: + return lhs + "value"; + case Action::kReturn: + return lhs; + } + }(); + + auto vars = WithVars({ + {"cc_type", list_type.CcType()}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs.ToCcVarName()}, + {"lhs_defined", lhs_defined.ToCcVarName()}, + {"lhs_element", lhs_element.ToCcVarName()}, + {"rhs_element", rhs_element.ToCcVarName()}, + }); + + const auto check_json_type = UnIndentedSource(R"( + if (!$rhs$.is_array()) { + return absl::InvalidArgumentError("$rhs$ expected to be array."); + } + )"); + Println(check_json_type); + Println(); + + Println("$cc_type$ $lhs_defined$;"); + Println("for (const nlohmann::json& $rhs_element$ : $rhs$) {"); + { + auto indent = WithIndent(); + PrintNullableFromJson(Action::kDef, list_type.element_type(), + list_type.element_maybe_null(), lhs_element, + rhs_element, lang_name); + Println("$lhs_defined$.push_back(std::move($lhs_element$));"); + } + Println("}"); + + switch (action) { + case Action::kDef: + // Nothing here. 
+ break; + case Action::kAssign: + Println("$lhs$ = std::move($lhs_defined$);"); + break; + case Action::kReturn: + Println("return $lhs_defined$;"); + break; + } +} + +void AstFromJsonPrinter::PrintFromJson(Action action, const Type &type, + const Symbol &lhs, const Symbol &rhs, + absl::string_view lang_name) { + switch (type.kind()) { + case TypeKind::kList: { + const auto &list_type = static_cast(type); + PrintListFromJson(action, list_type, lhs, rhs, lang_name); + break; + } + + case TypeKind::kVariant: { + const auto &variant_type = static_cast(type); + PrintVariantFromJson(action, variant_type, lhs, rhs, lang_name); + break; + } + + case TypeKind::kClass: { + const auto &class_type = static_cast(type); + PrintClassFromJson(action, class_type, lhs, rhs, lang_name); + break; + } + + case TypeKind::kEnum: { + const auto &enum_type = static_cast(type); + PrintEnumFromJson(action, enum_type, lhs, rhs, lang_name); + break; + } + + case TypeKind::kBuiltin: { + const auto &builtin_type = static_cast(type); + PrintBuiltinFromJson(action, CheckJsonType::kYes, builtin_type, lhs, rhs); + break; + } + } +} + +void AstFromJsonPrinter::PrintNullableFromJson(Action action, const Type &type, + MaybeNull maybe_null, + const Symbol &lhs, + const Symbol &rhs, + absl::string_view lang_name) { + auto vars = WithVars({ + {"cc_type", type.CcType(maybe_null)}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs.ToCcVarName()}, + }); + + switch (maybe_null) { + case MaybeNull::kYes: { + switch (action) { + case Action::kDef: + Print("$cc_type$ $lhs$;\n"); + ABSL_FALLTHROUGH_INTENDED; + + case Action::kAssign: + Print("if (!$rhs$.is_null()) {\n"); + { + auto indent = WithIndent(); + PrintFromJson(Action::kAssign, type, lhs, rhs, lang_name); + } + Print("}\n"); + break; + + case Action::kReturn: { + const auto nullopt_on_null = UnIndentedSource(R"( + if ($rhs$.is_null()) { + return std::nullopt; + } + )"); + Println(nullopt_on_null); + + PrintFromJson(Action::kReturn, type, lhs, rhs, 
lang_name); + break; + } + } + + break; + } + + case MaybeNull::kNo: { + Println("if ($rhs$.is_null()) {"); + Println(" return absl::InvalidArgumentError(\"$rhs$ is null.\");"); + Println("}"); + + PrintFromJson(action, type, lhs, rhs, lang_name); + break; + } + } +} + +void AstFromJsonPrinter::PrintGetFieldFunction(const std::string &node_name, + const FieldDef &field, + absl::string_view lang_name) { + const Symbol json_field_name = Symbol("json") + field.name(); + + auto vars = WithVars({ + {"NodeType", (Symbol(lang_name) + node_name).ToPascalCase()}, + {"cc_type", CcType(field)}, + {"fieldName", field.name().ToCamelCase()}, + {"FieldName", field.name().ToPascalCase()}, + {"field_name", field.name().ToCcVarName()}, + {"field_name_it", (field.name() + "it").ToCcVarName()}, + {"json_variable", kJsonValueVariableName}, + {"json_field_name", json_field_name.ToCcVarName()}, + }); + + Println("absl::StatusOr<$cc_type$>"); + Println( + "$NodeType$::Get$FieldName$(const nlohmann::json& $json_variable$) {"); + { + auto indent = WithIndent(); + const auto status_if_undefined = UnIndentedSource(R"cc( + auto $field_name_it$ = $json_variable$.find("$fieldName$"); + if ($field_name_it$ == $json_variable$.end()) { + return absl::InvalidArgumentError("`$fieldName$` is undefined."); + } + const nlohmann::json& $json_field_name$ = $field_name_it$.value(); + )cc"); + + const auto nullopt_if_undefined = UnIndentedSource(R"cc( + auto $field_name_it$ = $json_variable$.find("$fieldName$"); + if ($field_name_it$ == $json_variable$.end()) { + return std::nullopt; + } + const nlohmann::json& $json_field_name$ = $field_name_it$.value(); + )cc"); + + switch (field.optionalness()) { + case OPTIONALNESS_UNSPECIFIED: { + LOG(FATAL) << "Invalid Optionalness. 
Should be a bug."; + } + case OPTIONALNESS_REQUIRED: { + Println(status_if_undefined); + Println(); + + PrintNullableFromJson(Action::kReturn, field.type(), MaybeNull::kNo, + /*lhs=*/field.name(), /*rhs=*/json_field_name, + lang_name); + break; + } + case OPTIONALNESS_MAYBE_NULL: { + Println(status_if_undefined); + Println(); + + PrintNullableFromJson(Action::kReturn, field.type(), MaybeNull::kYes, + /*lhs=*/field.name(), /*rhs=*/json_field_name, + lang_name); + break; + } + case OPTIONALNESS_MAYBE_UNDEFINED: { + Println(nullopt_if_undefined); + Println(); + + PrintNullableFromJson(Action::kReturn, field.type(), MaybeNull::kNo, + /*lhs=*/field.name(), /*rhs=*/json_field_name, + lang_name); + break; + } + } + } + Println("}"); +} + +void AstFromJsonPrinter::PrintFromJsonFunction(const NodeDef &node, + absl::string_view lang_name) { + auto vars = WithVars({ + {"NodeType", (Symbol(lang_name) + node.name()).ToPascalCase()}, + {"json_variable", kJsonValueVariableName}, + }); + + Println("absl::StatusOr>"); + Println("$NodeType$::FromJson(const nlohmann::json& $json_variable$) {"); + { + auto indent = WithIndent(); + + const auto check_is_object = UnIndentedSource(R"cc( + if (!$json_variable$.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + )cc"); + Println(check_is_object); + Println(); + + if (!node.children().empty()) { + // This is a non-leaf type. + // We get the `type` field and dispatch the corresponding FromJson() + // function. 
+ + Println( + "MALDOCA_ASSIGN_OR_RETURN" + "(std::string type, GetType($json_variable$));"); + Println(); + + IfStmtPrinter if_stmt_printer(this); + for (const NodeDef *descendent : node.descendants()) { + auto vars = WithVars({ + {"DescendentType", + (Symbol(lang_name) + descendent->name()).ToPascalCase()}, + {"DescendentTypeNoLangName", descendent->name()}, + }); + if_stmt_printer.PrintCase({ + [&] { Print("type == \"$DescendentTypeNoLangName$\""); }, + [&] { + Println("return $DescendentType$::FromJson($json_variable$);"); + }, + }); + } + Println(); + + Print("return absl::InvalidArgumentError"); + Println(R"((absl::StrCat("Invalid type: ", type));)"); + + } else { + // This is a leaf type. + // We get all the fields and call the constructor. + + struct NodeFieldPair { + std::string node_name; + Symbol field_name; + }; + std::vector node_field_pairs; + for (const NodeDef *ancestor : node.ancestors()) { + for (const FieldDef &field : ancestor->fields()) { + node_field_pairs.push_back({ancestor->name(), field.name()}); + } + } + for (const FieldDef &field : node.fields()) { + node_field_pairs.push_back({node.name(), field.name()}); + } + + for (const NodeFieldPair &node_field_pair : node_field_pairs) { + auto vars = WithVars({ + {"NodeType", + (Symbol(lang_name) + node_field_pair.node_name).ToPascalCase()}, + {"field_name", node_field_pair.field_name.ToCcVarName()}, + {"FieldName", node_field_pair.field_name.ToPascalCase()}, + }); + Println( + "MALDOCA_ASSIGN_OR_RETURN(auto $field_name$, " + "$NodeType$::Get$FieldName$($json_variable$));"); + } + + Println(); + + // Call the constructor. 
+ Print("return absl::make_unique<$NodeType$>(\n"); + { + auto indent = WithIndent(4); + TabPrinter tab_printer{{ + .print_separator = [this] { Print(",\n"); }, + }}; + for (const FieldDef *field : node.aggregated_fields()) { + auto vars = WithVars({ + {"field_name", field->name().ToCcVarName()}, + }); + + tab_printer.Print(); + Print("std::move($field_name$)"); + } + } + + Println(");"); + } + } + Println("}"); +} + +std::string PrintAstFromJson(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path) { + std::string str; + { + google::protobuf::io::StringOutputStream os(&str); + AstFromJsonPrinter printer(&os); + printer.PrintAst(ast, cc_namespace, ast_path); + } + + return str; +} + +// ============================================================================= +// IrTableGenPrinter +// ============================================================================= + +void IrTableGenPrinter::PrintAst(const AstDef &ast, absl::string_view ir_path) { + PrintLicense(); + Println(); + + PrintCodeGenerationWarning(); + Println(); + + // E.g. lang_name == "js", then ir_name == "jsir". + const auto ir_name = absl::StrCat(ast.lang_name(), "ir"); + + // E.g. "/jsir_ops.generated.td". 
+ const auto td_path = absl::StrCat(ir_path, "/", ir_name, "_ops.generated.td"); + + PrintEnterHeaderGuard(td_path); + Println(); + + std::vector imports = { + "mlir/Interfaces/ControlFlowInterfaces.td", + "mlir/Interfaces/InferTypeOpInterface.td", + "mlir/Interfaces/LoopLikeInterface.td", + "mlir/Interfaces/SideEffectInterfaces.td", + "mlir/IR/OpBase.td", + "mlir/IR/SymbolInterfaces.td", + absl::StrCat(ir_path, "/interfaces.td"), + absl::StrCat(ir_path, "/", ast.lang_name(), "ir_dialect.td"), + absl::StrCat(ir_path, "/", ast.lang_name(), "ir_types.td"), + }; + for (const auto &import : imports) { + Println(absl::StrCat("include \"", import, "\"")); + } + Println(); + + bool has_expr_region = false; + bool has_exprs_region = false; + for (const auto *node : ast.topological_sorted_nodes()) { + for (const auto *field : node->aggregated_fields()) { + if (!field->enclose_in_region()) { + continue; + } + if (field->kind() != FIELD_KIND_LVAL && + field->kind() != FIELD_KIND_RVAL) { + continue; + } + if (field->type().IsA()) { + has_exprs_region = true; + } else { + has_expr_region = true; + } + } + } + + const auto region_end_comment = UnIndentedSource(R"( +// $ir$.*_region_end: An artificial op at the end of a region to collect +// expression-related values. +// +// Take $ir$.exprs_region_end as example: +// ====================================== +// +// Consider the following function declaration: +// ``` +// function foo(arg1, arg2 = defaultValue) { +// ... +// } +// ``` +// +// We lower it to the following IR (simplified): +// ``` +// %0 = $ir$.identifier_ref {"foo"} +// $ir$.function_declaration(%0) ( +// // params +// { +// %1 = $ir$.identifier_ref {"a"} +// %2 = $ir$.identifier_ref {"b"} +// %3 = $ir$.identifier {"defaultValue"} +// %4 = $ir$.assignment_pattern_ref(%2, %3) +// $ir$.exprs_region_end(%1, %4) +// }, +// // body +// { +// ... +// } +// ) +// ``` +// +// We can see that: +// +// 1. 
We put the parameter-related ops in a region, instead of taking them as +// normal arguments. In other words, we don't do this: +// +// ``` +// %0 = $ir$.identifier_ref {"foo"} +// %1 = $ir$.identifier_ref {"a"} +// %2 = $ir$.identifier_ref {"b"} +// %3 = $ir$.identifier {"defaultValue"} +// %4 = $ir$.assignment_pattern_ref(%2, %3) +// $ir$.function_declaration(%0, [%1, %4]) ( +// // body +// { +// ... +// } +// ) +// ``` +// +// The reason is that sometimes an argument might have a default value, and +// the evaluation of that default value happens once for each function call +// (i.e. it happens "within" the function). If we take the parameter as +// normal argument, then %3 is only evaluated once - at function definition +// time. +// +// 2. Even though the function has two parameters, we use 4 ops to represent +// them. This is because some parameters are more complex and require more +// than one op. +// +// 3. We use "$ir$.exprs_region_end" to list the "top-level" ops for the +// parameters. In the example above, ops [%2, %3, %4] all represent the +// parameter "b = defaultValue", but %4 is the top-level one. In other words, +// %4 is the root of the tree [%2, %3, %4]. +// +// 4. Strictly speaking, we don't really need "$ir$.exprs_region_end". The ops +// within the "params" region form several trees, and we can figure out what +// the roots are (a root is an op whose return value is not used by any other +// op). So the use of "$ir$.exprs_region_end" is mostly for convenience. 
+ )"); + + if (has_expr_region || has_exprs_region) { + Symbol ir{absl::StrCat(ast.lang_name(), "ir")}; + + auto vars = WithVars({ + {"ir", ir.ToSnakeCase()}, + {"Ir", ir.ToPascalCase()}, + }); + Println(region_end_comment); + + if (has_expr_region) { + const auto expr_region_end = UnIndentedSource(R"( + def $Ir$ExprRegionEndOp : $Ir$_Op<"expr_region_end", [Terminator]> { + let arguments = (ins + AnyType: $$argument + ); + } + )"); + Println(expr_region_end); + Println(); + } + + if (has_exprs_region) { + const auto exprs_region_end = UnIndentedSource(R"( + def $Ir$ExprsRegionEndOp : $Ir$_Op<"exprs_region_end", [Terminator]> { + let arguments = (ins + Variadic: $$arguments + ); + } + )"); + Println(exprs_region_end); + Println(); + } + } + + for (const auto *node : ast.topological_sorted_nodes()) { + if (!node->should_generate_ir_op()) { + continue; + } + + for (auto kind : node->aggregated_kinds()) { + PrintNode(ast, *node, kind); + } + } + + PrintExitHeaderGuard(td_path); +} + +void IrTableGenPrinter::PrintNode(const AstDef &ast, const NodeDef &node, + FieldKind kind) { + auto ir_name = absl::StrCat(ast.lang_name(), "ir"); + auto hir_name = + absl::StrCat(ast.lang_name(), node.has_control_flow() ? 
"hir" : "ir"); + + auto vars = WithVars({ + {"OpName", node.ir_op_name(ast.lang_name(), kind).value().ToPascalCase()}, + {"op_mnemonic", node.ir_op_mnemonic(kind).value().ToCcVarName()}, + {"Name", node.name()}, + {"name", Symbol(node.name()).ToCcVarName()}, + {"IrName", Symbol(ir_name).ToPascalCase()}, + {"HirName", Symbol(hir_name).ToPascalCase()}, + }); + + std::vector traits; + for (const NodeDef *parent : node.parents()) { + if (!absl::c_linear_search(parent->aggregated_kinds(), kind)) { + continue; + } + auto parent_ir_op_name = parent->ir_op_name(ast.lang_name(), kind); + if (!parent_ir_op_name.has_value()) { + continue; + } + traits.push_back(*parent_ir_op_name + "Traits"); + } + + // When there is more than one variadic operand, we must append the + // AttrSizedOperandSegments trait. This is because MLIR internally stores + // operands as a single array and without additional information, it cannot + // attributes ranges of that array into the corresponding variadic operands. + // + // MLIR doesn't allow universally adding AttrSizedOperandSegments - only ops + // with more than one variadic operand are allowed. 
+ // + // See: https://mlir.llvm.org/docs/OpDefinitions/#variadic-operands + size_t num_variadic_operands = 0; + for (const FieldDef &field : node.fields()) { + if (field.enclose_in_region()) { + continue; + } + + switch (field.kind()) { + case FIELD_KIND_UNSPECIFIED: { + LOG(QFATAL) << node.name() << "::" << field.name().ToCcVarName() + << ": FieldKind unspecified."; + } + case FIELD_KIND_ATTR: + case FIELD_KIND_STMT: { + break; + } + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: { + if (field.type().IsA() || + field.optionalness() == OPTIONALNESS_MAYBE_NULL || + field.optionalness() == OPTIONALNESS_MAYBE_UNDEFINED) { + num_variadic_operands++; + } + } + } + } + if (num_variadic_operands > 1) { + traits.push_back(Symbol("AttrSizedOperandSegments")); + } + + if (absl::c_any_of(node.aggregated_fields(), FieldIsRegion)) { + traits.push_back(Symbol("NoTerminator")); + } + + for (auto mlir_trait : node.aggregated_additional_mlir_traits()) { + switch (mlir_trait) { + case MLIR_TRAIT_INVALID: + LOG(FATAL) << "Invalid MlirTrait."; + case MLIR_TRAIT_PURE: + traits.push_back(Symbol("Pure")); + break; + case MLIR_TRAIT_ISOLATED_FROM_ABOVE: + traits.push_back(Symbol("IsolatedFromAbove")); + break; + } + } + + if (traits.empty()) { + Println("def $OpName$ : $IrName$_Op<\"$op_mnemonic$\", []> {"); + } else { + // Example: + // ``` + // def JsirBinaryExpressionOp : Jsir_Op< + // "binary_expression", [ + // DeclareOpInterfaceMethods, + // DeclareOpInterfaceMethods + // ]> { + // ``` + Print( + "def $OpName$ : $HirName$_Op<\n" + " \"$op_mnemonic$\", [\n"); + + { + auto indent = WithIndent(8); + TabPrinter tab_printer{{ + .print_separator = [&] { Print(",\n"); }, + }}; + + for (const Symbol &trait : traits) { + auto vars = WithVars({ + {"Trait", trait.ToPascalCase()}, + }); + + tab_printer.Print(); + Print("$Trait$"); + } + } + + Println("\n ]> {"); + } + { + auto indent = WithIndent(); + TabPrinter line_separator_printer{{ + .print_separator = [&] { Print("\n"); }, + }}; + if 
(node.has_fold()) { + line_separator_printer.Print(); + Println("let hasFolder = 1;"); + } + + if (absl::c_any_of(node.aggregated_fields(), FieldIsArgument)) { + line_separator_printer.Print(); + + Println("let arguments = (ins"); + { + auto indent = WithIndent(); + TabPrinter separator_printer{{ + .print_separator = [&] { Print(",\n"); }, + }}; + for (const auto *field : node.aggregated_fields()) { + if (!FieldIsArgument(field)) { + continue; + } + + separator_printer.Print(); + PrintArgument(ast, node, *field); + } + } + Println(); + Println(");"); + } + + if (absl::c_any_of(node.aggregated_fields(), FieldIsRegion)) { + line_separator_printer.Print(); + + Println("let regions = (region"); + { + auto indent = WithIndent(); + TabPrinter separator_printer{{ + .print_separator = [&] { Print(",\n"); }, + }}; + for (const auto *field : node.aggregated_fields()) { + if (!FieldIsRegion(field)) { + continue; + } + + separator_printer.Print(); + PrintRegion(ast, node, *field); + } + } + Println(); + Println(");"); + } + + // Only expressions have results. 
+ if (kind == FIELD_KIND_LVAL || kind == FIELD_KIND_RVAL) { + line_separator_printer.Print(); + + Println("let results = (outs"); + Println(" $IrName$AnyType"); + Println(");"); + } + } + + Println("}"); + Println(); +} + +void IrTableGenPrinter::PrintArgument(const AstDef &ast, const NodeDef &node, + const FieldDef &field) { + auto vars = WithVars({ + {"type", field.type().TdType(field.optionalness(), field.kind())}, + {"name", field.name().ToCcVarName()}, + }); + Print("$type$: $$$name$"); +} + +void IrTableGenPrinter::PrintRegion(const AstDef &ast, const NodeDef &node, + const FieldDef &field) { + std::string region_type = [&] { + switch (field.kind()) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "FieldKind is unspecified."; + case FIELD_KIND_ATTR: + LOG(FATAL) << "Region of attributes not supported."; + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + if (field.type().IsA()) { + return "ExprsRegion"; + } else { + return "ExprRegion"; + } + case FIELD_KIND_STMT: + if (field.type().IsA()) { + return "StmtsRegion"; + } else { + return "StmtRegion"; + } + } + }(); + + switch (field.optionalness()) { + case OPTIONALNESS_UNSPECIFIED: + LOG(FATAL) << "Optionalness unspecified."; + case OPTIONALNESS_REQUIRED: + break; + case OPTIONALNESS_MAYBE_NULL: + case OPTIONALNESS_MAYBE_UNDEFINED: + region_type = absl::StrCat("OptionalRegion<", region_type, ">"); + } + + auto vars = WithVars({ + {"name", field.name().ToCcVarName()}, + {"RegionType", region_type}, + }); + + Print("$RegionType$: $$$name$"); +} + +std::string PrintIrTableGen(const AstDef &ast, absl::string_view ir_path) { + std::string str; + { + google::protobuf::io::StringOutputStream os(&str); + IrTableGenPrinter printer(&os); + printer.PrintAst(ast, ir_path); + } + + return str; +} + +// ============================================================================= +// AstToIrSourcePrinter +// ============================================================================= + +void 
AstToIrSourcePrinter::PrintAst(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path, + absl::string_view ir_path) { + auto ast_header_path = GetAstHeaderPath(ast_path); + + PrintLicense(); + Println(); + + PrintCodeGenerationWarning(); + Println(); + + Println("// IWYU pragma: begin_keep"); + Println("// NOLINTBEGIN(whitespace/line_length)"); + Println("// clang-format off"); + Println(); + + PrintIncludeHeader( + absl::StrCat(ir_path, "/conversion/ast_to_", ast.lang_name(), "ir.h")); + Println(); + + Println("#include "); + Println("#include "); + Println("#include "); + Println(); + + PrintIncludeHeaders({ + "llvm/ADT/APFloat.h", + "mlir/IR/Attributes.h", + "mlir/IR/Block.h", + "mlir/IR/Builders.h", + "mlir/IR/BuiltinAttributes.h", + "mlir/IR/BuiltinTypes.h", + "mlir/IR/Operation.h", + "mlir/IR/Region.h", + "mlir/IR/Value.h", + "absl/cleanup/cleanup.h", + "absl/log/check.h", + "absl/log/log.h", + "absl/types/optional.h", + "absl/types/variant.h", + std::string(ast_header_path), + absl::StrCat(ir_path, "/ir.h"), + }); + Println(); + + PrintEnterNamespace(cc_namespace); + Println(); + + for (const auto *node : ast.topological_sorted_nodes()) { + if (!node->children().empty()) { + for (FieldKind kind : node->aggregated_kinds()) { + PrintNonLeafNode(ast, *node, kind); + } + } + + if (!node->should_generate_ir_op()) { + continue; + } + + for (FieldKind kind : node->aggregated_kinds()) { + PrintLeafNode(ast, *node, kind); + } + } + + Println("// clang-format on"); + Println("// NOLINTEND(whitespace/line_length)"); + Println("// IWYU pragma: end_keep"); + Println(); + + PrintExitNamespace(cc_namespace); +} + +static Symbol GetVisitor(const NodeDef &node, FieldKind kind) { + auto visitor = Symbol("Visit") + node.name(); + if (kind == FIELD_KIND_ATTR) { + visitor += "Attr"; + } + if (kind == FIELD_KIND_LVAL) { + visitor += "Ref"; + } + return visitor; +} + +void AstToIrSourcePrinter::PrintNonLeafNode(const AstDef &ast, + const NodeDef &node, + 
FieldKind kind) { + auto ir_op_name = node.ir_op_name(ast.lang_name(), kind); + std::string return_type; + if (ir_op_name.has_value()) { + return_type = ir_op_name.value().ToPascalCase(); + } else { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Invalid FieldKind: FIELD_KIND_UNSPECIFIED."; + case FIELD_KIND_ATTR: { + return_type = "mlir::Attribute"; + break; + } + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: { + return_type = "mlir::Value"; + break; + } + case FIELD_KIND_STMT: { + return_type = "mlir::Operation*"; + break; + } + } + } + auto ir_name = Symbol(absl::StrCat(ast.lang_name(), "ir")); + auto visitor = GetVisitor(node, kind); + + auto vars = WithVars({ + {"Ret", return_type}, + {"Name", (Symbol(ast.lang_name()) + node.name()).ToPascalCase()}, + {"IrName", ir_name.ToPascalCase()}, + {"Visitor", visitor.ToPascalCase()}, + }); + + Println("$Ret$ AstTo$IrName$::$Visitor$(const $Name$ *node) {"); + { + auto indent = WithIndent(); + for (const NodeDef *leaf : node.leafs()) { + auto vars = WithVars({ + {"LeafName", (Symbol(ast.lang_name()) + leaf->name()).ToPascalCase()}, + {"leaf_name", Symbol(leaf->name()).ToCcVarName()}, + {"LeafVisitor", GetVisitor(*leaf, kind).ToPascalCase()}, + }); + Println( + "if (auto *$leaf_name$ = dynamic_cast(node)) {"); + Println(" return $LeafVisitor$($leaf_name$);"); + Println("}"); + } + + Println("LOG(FATAL) << \"Unreachable code.\";"); + } + Println("}"); + Println(); +} + +void AstToIrSourcePrinter::PrintLeafNode(const AstDef &ast, const NodeDef &node, + FieldKind kind) { + auto ir_op_name = node.ir_op_name(ast.lang_name(), kind).value(); + auto ir_name = Symbol(absl::StrCat(ast.lang_name(), "ir")); + + auto visitor = Symbol("Visit") + node.name(); + if (kind == FIELD_KIND_LVAL) { + visitor += "Ref"; + } + + auto creator = Symbol("Create"); + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_ATTR: + LOG(FATAL) << "Unsupported kind: " << kind; + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + 
creator += "Expr"; + break; + case FIELD_KIND_STMT: + creator += "Stmt"; + break; + } + + auto vars = WithVars({ + {"OpName", ir_op_name.ToPascalCase()}, + {"Name", (Symbol(ast.lang_name()) + node.name()).ToPascalCase()}, + {"IrName", ir_name.ToPascalCase()}, + {"Visitor", visitor.ToPascalCase()}, + {"Creator", creator.ToPascalCase()}, + }); + + Println("$OpName$ AstTo$IrName$::$Visitor$(const $Name$ *node) {"); + { + auto indent = WithIndent(); + + for (const auto *field : node.aggregated_fields()) { + if (!FieldIsArgument(field)) { + continue; + } + + PrintField(ast, node, *field); + } + + bool has_regions = absl::c_any_of(node.aggregated_fields(), FieldIsRegion); + if (has_regions) { + Print("auto op = "); + } else { + Print("return "); + } + + Print("$Creator$<$OpName$>(node"); + { + auto indent = WithIndent(4); + for (const auto *field : node.aggregated_fields()) { + if (!FieldIsArgument(field)) { + continue; + } + + const auto mlir_field_name = (Symbol("mlir") + field->name()); + auto vars = WithVars({ + {"mlir_field_name", mlir_field_name.ToCcVarName()}, + }); + + Print(", $mlir_field_name$"); + } + } + Println(");"); + + if (has_regions) { + for (const auto *field : node.aggregated_fields()) { + if (FieldIsRegion(field)) { + PrintRegion(ast, node, *field); + } + } + + Println("return op;"); + } + } + + Println("}"); + Println(); +} + +void AstToIrSourcePrinter::PrintField(const AstDef &ast, const NodeDef &node, + const FieldDef &field) { + MaybeNull maybe_null = OptionalnessToMaybeNull(field.optionalness()); + + auto lhs = Symbol("mlir") + field.name(); + auto rhs = absl::StrCat("node->", field.name().ToCcVarName(), "()"); + PrintNullableToIr(ast, Action::kDef, field.type(), maybe_null, RefOrVal::kRef, + field.kind(), lhs, rhs); +} + +void AstToIrSourcePrinter::PrintRegion(const AstDef &ast, const NodeDef &node, + const FieldDef &field) { + MaybeNull maybe_null = OptionalnessToMaybeNull(field.optionalness()); + + auto lhs = Symbol("mlir") + field.name(); + 
auto lhs_region = lhs + "region"; + auto rhs = absl::StrCat("node->", field.name().ToCcVarName(), "()"); + auto ir_name = Symbol(absl::StrCat(ast.lang_name(), "ir")); + + auto vars = WithVars({ + {"lhs", lhs.ToCcVarName()}, + {"lhs_region", lhs_region.ToCcVarName()}, + {"mlirGetter", field.name().ToMlirGetter()}, + {"rhs", rhs}, + }); + + auto populate_region = [&] { + Println("mlir::Region &$lhs_region$ = op.$mlirGetter$();"); + Println("AppendNewBlockAndPopulate($lhs_region$, [&] {"); + { + auto indent = WithIndent(); + + Action action = [&] { + switch (field.kind()) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + LOG(FATAL) << "Unsupported FieldKind: " << field.kind(); + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: { + return Action::kDef; + } + case FIELD_KIND_STMT: { + return Action::kCreate; + } + } + }(); + + Symbol region_end_op = GetRegionEndOp(ast, field); + PrintToIr(ast, action, field.type(), RefOrVal::kRef, field.kind(), lhs, + rhs); + + auto vars = WithVars({ + {"RegionEndOp", region_end_op.ToPascalCase()}, + }); + + switch (action) { + case Action::kAssign: + LOG(FATAL) << "Unsupported Action: Assign."; + case Action::kCreate: + break; + case Action::kDef: { + Println("CreateStmt<$RegionEndOp$>(node, $lhs$);"); + break; + } + } + } + Println("});"); + }; + + switch (maybe_null) { + case MaybeNull::kYes: { + Println("if ($rhs$.has_value()) {"); + { + auto indent = WithIndent(); + absl::StrAppend(&rhs, ".value()"); + auto vars = WithVars({ + {"rhs", rhs}, + }); + populate_region(); + } + Println("}"); + break; + } + case MaybeNull::kNo: + populate_region(); + break; + } +} + +void AstToIrSourcePrinter::PrintBuiltinToIr(const AstDef &ast, Action action, + const BuiltinType &type, + const Symbol &lhs, + const std::string &rhs) { + auto vars = WithVars({ + {"mlir_type", type.CcMlirBuilderType(FIELD_KIND_ATTR)}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs}, + }); + + switch (action) { + case 
Action::kDef: + Print("$mlir_type$ "); + ABSL_FALLTHROUGH_INTENDED; + case Action::kAssign: + Print("$lhs$ = "); + break; + case Action::kCreate: + break; + } + + switch (type.builtin_kind()) { + case BuiltinTypeKind::kBool: { + Print("builder_.getBoolAttr($rhs$)"); + break; + } + case BuiltinTypeKind::kInt64: { + Print("builder_.getI64IntegerAttr($rhs$)"); + break; + } + case BuiltinTypeKind::kString: { + Print("builder_.getStringAttr($rhs$)"); + break; + } + case BuiltinTypeKind::kDouble: { + Print("builder_.getF64FloatAttr($rhs$)"); + break; + } + } + + Println(";"); +} + +void AstToIrSourcePrinter::PrintClassToIr(const AstDef &ast, Action action, + const ClassType &type, FieldKind kind, + const Symbol &lhs, + const std::string &rhs) { + auto vars = WithVars({ + {"ClassName", type.name().ToPascalCase()}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs}, + }); + + switch (action) { + case Action::kDef: { + auto vars = WithVars({ + {"cc_mlir_type", type.CcMlirBuilderType(kind)}, + }); + Print("$cc_mlir_type$ "); + ABSL_FALLTHROUGH_INTENDED; + } + case Action::kAssign: + Print("$lhs$ = "); + break; + case Action::kCreate: + break; + } + + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_ATTR: + Println("Visit$ClassName$Attr($rhs$);"); + break; + case FIELD_KIND_RVAL: + case FIELD_KIND_STMT: { + Println("Visit$ClassName$($rhs$);"); + break; + } + case FIELD_KIND_LVAL: { + Println("Visit$ClassName$Ref($rhs$);"); + break; + } + } +} + +void AstToIrSourcePrinter::PrintClassToIr(const AstDef &ast, Action action, + const ClassType &type, + RefOrVal ref_or_val, FieldKind kind, + const Symbol &lhs, + const std::string &rhs) { + switch (ref_or_val) { + case RefOrVal::kRef: + return PrintClassToIr(ast, action, type, kind, lhs, rhs); + case RefOrVal::kVal: + return PrintClassToIr(ast, action, type, kind, lhs, + absl::StrCat(rhs, ".get()")); + } +} + +void AstToIrSourcePrinter::PrintEnumToIr(const AstDef &ast, Action action, + const EnumType &type, + const Symbol 
&lhs, + const std::string &rhs) { + auto enum_name = (Symbol(ast.lang_name()) + type.name()).ToPascalCase(); + auto rhs_str = absl::StrCat(enum_name, "ToString(", rhs, ")"); + + BuiltinType string_type{BuiltinTypeKind::kString, ast.lang_name()}; + return PrintBuiltinToIr(ast, action, string_type, lhs, rhs_str); +} + +void AstToIrSourcePrinter::PrintVariantToIr(const AstDef &ast, Action action, + const VariantType &type, + RefOrVal ref_or_val, FieldKind kind, + const Symbol &lhs, + const std::string &rhs) { + auto vars = WithVars({ + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs}, + }); + + Action case_action; + switch (action) { + case Action::kDef: { + auto vars = WithVars({ + {"cc_mlir_type", type.CcMlirBuilderType(kind)}, + }); + Println("$cc_mlir_type$ $lhs$;"); + case_action = Action::kAssign; + break; + } + case Action::kAssign: + case_action = Action::kAssign; + break; + case Action::kCreate: + case_action = Action::kCreate; + break; + } + + Println("switch ($rhs$.index()) {"); + { + auto indent = WithIndent(); + + for (size_t i = 0; i != type.types().size(); ++i) { + auto vars = WithVars({ + {"i", std::to_string(i)}, + }); + + Println("case $i$: {"); + { + auto indent = WithIndent(); + const ScalarType &scalar_type = *type.types()[i]; + PrintToIr(ast, case_action, scalar_type, ref_or_val, kind, lhs, + absl::StrFormat("std::get<%zu>(%s)", i, rhs)); + Println("break;"); + } + + Println("}"); + } + + Println("default:"); + Println(" LOG(FATAL) << \"Unreachable code.\";"); + } + Println("}"); +} + +void AstToIrSourcePrinter::PrintListToIr(const AstDef &ast, Action action, + const ListType &type, FieldKind kind, + const Symbol &lhs, + const std::string &rhs) { + const auto lhs_element = Symbol("mlir_element"); + const auto rhs_element = "element"; + + auto vars = WithVars({ + {"lhs", lhs.ToCcVarName()}, + {"lhs_data", (lhs + "data").ToCcVarName()}, + {"rhs", rhs}, + {"lhs_element", lhs_element.ToCcVarName()}, + {"rhs_element", rhs_element}, + }); + + switch (kind) { 
+ case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "FieldKind unspecified."; + case FIELD_KIND_STMT: { + // Case: List of Statements. + CHECK(action == Action::kCreate) + << "We never collect statement ops in a vector."; + + Println("for (const auto &$rhs_element$ : *$rhs$) {"); + { + auto indent = WithIndent(); + PrintNullableToIr(ast, Action::kCreate, type.element_type(), + type.element_maybe_null(), RefOrVal::kVal, kind, + lhs_element, rhs_element); + } + Println("}"); + break; + } + case FIELD_KIND_ATTR: { + // Case: List of Attributes. + // + // We first create and fill a std::vector and then + // convert it into a mlir::ArrayAttr (what the builder takes). + + Println("std::vector $lhs_data$;"); + Println("for (const auto &$rhs_element$ : *$rhs$) {"); + { + auto indent = WithIndent(); + PrintNullableToIr(ast, Action::kDef, type.element_type(), + type.element_maybe_null(), RefOrVal::kVal, kind, + lhs_element, rhs_element); + Println("$lhs_data$.push_back(std::move($lhs_element$));"); + } + Println("}"); + + switch (action) { + case Action::kDef: { + Println("auto $lhs$ = builder_.getArrayAttr($lhs_data$);"); + break; + } + case Action::kAssign: { + Println("$lhs$ = builder_.getArrayAttr($lhs_data$);"); + break; + } + case Action::kCreate: + LOG(FATAL) << "We never put attributes in a region."; + } + break; + } + + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: { + // Case: List of Values. + // + // We create and fill a std::vector which can be implicitly + // converted to a mlir::ValueRange (what the builder takes). + + switch (action) { + case Action::kDef: + Println("std::vector $lhs$;"); + break; + case Action::kAssign: + // Do nothing. 
+ break; + case Action::kCreate: + LOG(FATAL) << "We must put expressions in a vector."; + } + + Println("for (const auto &$rhs_element$ : *$rhs$) {"); + { + auto indent = WithIndent(); + switch (type.element_maybe_null()) { + case MaybeNull::kNo: { + PrintToIr(ast, Action::kDef, type.element_type(), RefOrVal::kVal, + kind, lhs_element, rhs_element); + break; + } + + case MaybeNull::kYes: { + // Unfortunately, in the std::vector we can't have any + // nullptr. In order to represent optional, we need the special + // irNoneOp. + + Println("mlir::Value $lhs_element$;"); + Println("if ($rhs_element$.has_value()) {"); + { + auto indent = WithIndent(); + PrintToIr(ast, Action::kAssign, type.element_type(), + RefOrVal::kVal, kind, lhs_element, + absl::StrCat(rhs_element, ".value()")); + } + Println("} else {"); + { + auto indent = WithIndent(); + auto none_op = + Symbol(absl::StrCat(ast.lang_name(), "ir")) + "NoneOp"; + auto vars = WithVars({ + {"NoneOp", none_op.ToPascalCase()}, + }); + + Println("$lhs_element$ = CreateExpr<$NoneOp$>(node);"); + } + Println("}"); + + break; + } + } + + Println("$lhs$.push_back(std::move($lhs_element$));"); + } + + Println("}"); + } + } +} + +void AstToIrSourcePrinter::PrintToIr(const AstDef &ast, Action action, + const Type &type, RefOrVal ref_or_val, + FieldKind kind, const Symbol &lhs, + const std::string &rhs) { + switch (type.kind()) { + case TypeKind::kBuiltin: { + const auto &builtin_type = static_cast(type); + return PrintBuiltinToIr(ast, action, builtin_type, lhs, rhs); + } + + case TypeKind::kClass: { + const auto &class_type = static_cast(type); + return PrintClassToIr(ast, action, class_type, ref_or_val, kind, lhs, + rhs); + } + + case TypeKind::kEnum: { + const auto &enum_type = static_cast(type); + return PrintEnumToIr(ast, action, enum_type, lhs, rhs); + } + + case TypeKind::kVariant: { + const auto &variant_type = static_cast(type); + return PrintVariantToIr(ast, action, variant_type, ref_or_val, kind, lhs, + rhs); + } + 
+ case TypeKind::kList: { + const auto &list_type = static_cast(type); + CHECK(ref_or_val == RefOrVal::kRef); + return PrintListToIr(ast, action, list_type, kind, lhs, rhs); + } + } +} + +void AstToIrSourcePrinter::PrintNullableToIr(const AstDef &ast, Action action, + const Type &type, + MaybeNull maybe_null, + RefOrVal ref_or_val, + FieldKind kind, const Symbol &lhs, + const std::string &rhs) { + auto vars = WithVars({ + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs}, + }); + + switch (maybe_null) { + case MaybeNull::kYes: { + Action non_null_action; + switch (action) { + case Action::kAssign: + non_null_action = Action::kAssign; + break; + case Action::kCreate: + non_null_action = Action::kCreate; + break; + case Action::kDef: { + auto vars = WithVars({ + {"mlir_type", type.CcMlirBuilderType(kind)}, + }); + Println("$mlir_type$ $lhs$;"); + non_null_action = Action::kAssign; + break; + } + } + Println("if ($rhs$.has_value()) {"); + { + auto indent = WithIndent(); + auto new_rhs = absl::StrCat(rhs, ".value()"); + PrintToIr(ast, non_null_action, type, ref_or_val, kind, lhs, new_rhs); + } + Println("}"); + break; + } + + case MaybeNull::kNo: { + PrintToIr(ast, action, type, ref_or_val, kind, lhs, rhs); + break; + } + } +} + +std::string PrintAstToIrSource(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path, + absl::string_view ir_path) { + std::string str; + { + google::protobuf::io::StringOutputStream os(&str); + AstToIrSourcePrinter printer(&os); + printer.PrintAst(ast, cc_namespace, ast_path, ir_path); + } + + return str; +} + +// ============================================================================= +// IrToAstSourcePrinter +// ============================================================================= + +void IrToAstSourcePrinter::PrintAst(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path, + absl::string_view ir_path) { + auto ast_header_path = GetAstHeaderPath(ast_path); + + PrintLicense(); + 
Println(); + + PrintCodeGenerationWarning(); + Println(); + + Println("// IWYU pragma: begin_keep"); + Println("// NOLINTBEGIN(whitespace/line_length)"); + Println("// clang-format off"); + Println(); + + PrintIncludeHeader( + absl::StrCat(ir_path, "/conversion/", ast.lang_name(), "ir_to_ast.h")); + Println(); + + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println("#include "); + Println(); + + PrintIncludeHeaders({ + "llvm/ADT/APFloat.h", + "llvm/ADT/TypeSwitch.h", + "llvm/Support/Casting.h", + "mlir/IR/Attributes.h", + "mlir/IR/Block.h", + "mlir/IR/Builders.h", + "mlir/IR/BuiltinAttributes.h", + "mlir/IR/BuiltinTypes.h", + "mlir/IR/Operation.h", + "mlir/IR/Region.h", + "mlir/IR/Value.h", + "absl/cleanup/cleanup.h", + "absl/log/check.h", + "absl/log/log.h", + "absl/status/status.h", + "absl/status/statusor.h", + "absl/strings/str_cat.h", + "absl/types/optional.h", + "absl/types/variant.h", + "maldoca/base/status_macros.h", + std::string(ast_header_path), + absl::StrCat(ir_path, "/ir.h"), + }); + Println(); + + PrintEnterNamespace(cc_namespace); + Println(); + + for (const auto *node : ast.topological_sorted_nodes()) { + if (!node->children().empty()) { + for (FieldKind kind : node->aggregated_kinds()) { + PrintNonLeafNode(ast, *node, kind); + } + } + + if (!node->should_generate_ir_op()) { + continue; + } + + for (FieldKind kind : node->aggregated_kinds()) { + PrintLeafNode(ast, *node, kind); + } + } + + Println("// clang-format on"); + Println("// NOLINTEND(whitespace/line_length)"); + Println("// IWYU pragma: end_keep"); + Println(); + + PrintExitNamespace(cc_namespace); +} + +void IrToAstSourcePrinter::PrintNonLeafNode(const AstDef &ast, + const NodeDef &node, + FieldKind kind) { + auto ir_op_name = node.ir_op_name(ast.lang_name(), kind); + std::string input_type; + if (ir_op_name.has_value()) { + input_type = ir_op_name->ToPascalCase(); + } else { + switch (kind) { + case 
FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Invalid FieldKind: FIELD_KIND_UNSPECIFIED."; + case FIELD_KIND_ATTR: + input_type = "mlir::Attribute"; + break; + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + case FIELD_KIND_STMT: + input_type = "mlir::Operation*"; + break; + } + } + auto ir_name = Symbol(absl::StrCat(ast.lang_name(), "ir")); + auto visitor = GetVisitor(node, kind); + + auto vars = WithVars({ + {"InputType", input_type}, + {"BaseName", + kind == FIELD_KIND_ATTR ? "mlir::Attribute" : "mlir::Operation*"}, + {"Name", (Symbol(ast.lang_name()) + node.name()).ToPascalCase()}, + {"name", kind == FIELD_KIND_ATTR ? "attr" : "op"}, + {"IrName", ir_name.ToPascalCase()}, + {"Visitor", visitor.ToPascalCase()}, + }); + + Println("absl::StatusOr>"); + Println("$IrName$ToAst::$Visitor$($InputType$ $name$) {"); + { + auto indent = WithIndent(); + Println("using Ret = absl::StatusOr>;"); + Println("return llvm::TypeSwitch<$BaseName$, Ret>($name$)"); + { + auto indent = WithIndent(); + for (const NodeDef *leaf : node.leafs()) { + auto vars = WithVars({ + {"LeafOpName", + leaf->ir_op_name(ast.lang_name(), kind)->ToPascalCase()}, + {"LeafVisitor", GetVisitor(*leaf, kind).ToPascalCase()}, + }); + Println(".Case([&]($LeafOpName$ $name$) {"); + Println(" return $LeafVisitor$($name$);"); + Println("})"); + } + + Println(".Default([&]($BaseName$ op) {"); + Println(" return absl::InvalidArgumentError(\"Unrecognized op\");"); + Println("});"); + } + } + Println("}"); + Println(); +} + +void IrToAstSourcePrinter::PrintLeafNode(const AstDef &ast, const NodeDef &node, + FieldKind kind) { + auto ir_op_name = node.ir_op_name(ast.lang_name(), kind).value(); + auto ir_name = Symbol(absl::StrCat(ast.lang_name(), "ir")); + + auto visitor = Symbol("Visit") + node.name(); + if (kind == FIELD_KIND_LVAL) { + visitor += "Ref"; + } + + auto vars = WithVars({ + {"OpName", ir_op_name.ToPascalCase()}, + {"Name", (Symbol(ast.lang_name()) + node.name()).ToPascalCase()}, + {"name", kind == 
FIELD_KIND_ATTR ? "attr" : "op"}, + {"IrName", ir_name.ToPascalCase()}, + {"Visitor", visitor.ToPascalCase()}, + }); + + Println("absl::StatusOr>"); + Println("$IrName$ToAst::$Visitor$($OpName$ $name$) {"); + { + auto indent = WithIndent(); + for (const auto *field : node.aggregated_fields()) { + if (FieldIsArgument(field)) { + PrintField(ast, node, *field); + } else if (FieldIsRegion(field)) { + PrintRegion(ast, node, *field); + } + } + + // Call the constructor. + Print("return Create<$Name$>(\n"); + { + auto indent = WithIndent(4); + Print("$name$"); + + for (const FieldDef *field : node.aggregated_fields()) { + if (!FieldIsArgument(field) && !FieldIsRegion(field)) { + continue; + } + + auto vars = WithVars({ + {"field_name", field->name().ToCcVarName()}, + }); + Print(",\nstd::move($field_name$)"); + } + } + + Println(");"); + } + Println("}"); + Println(); +} + +void IrToAstSourcePrinter::PrintField(const AstDef &ast, const NodeDef &node, + const FieldDef &field) { + MaybeNull maybe_null = OptionalnessToMaybeNull(field.optionalness()); + + auto mlir_getter = field.name().ToMlirGetter(); + + std::string rhs; + switch (field.kind()) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: { + rhs = absl::StrCat("op.", mlir_getter, "Attr()"); + break; + } + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: { + rhs = absl::StrCat("op.", mlir_getter, "()"); + break; + } + case FIELD_KIND_STMT: { + LOG(FATAL) << "Unsupported FieldKind."; + } + } + + PrintNullableFromIr(ast, Action::kDef, field.type(), maybe_null, field.kind(), + /*lhs=*/field.name(), rhs, RhsKind::kFieldGetterResult); +} + +void IrToAstSourcePrinter::PrintRegion(const AstDef &ast, const NodeDef &node, + const FieldDef &field) { + MaybeNull maybe_null = OptionalnessToMaybeNull(field.optionalness()); + + auto vars = WithVars({ + {"lhs", field.name().ToCcVarName()}, + {"cc_type", field.type().CcType(maybe_null)}, + {"mlirGetter", field.name().ToMlirGetter()}, + }); 
+ + auto print_from_region = [&](Action action) { + std::string rhs; + std::string rhs_getter; + RhsKind rhs_kind; + switch (field.kind()) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + LOG(FATAL) << "Unsupported FieldKind: " << field.kind(); + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: { + if (field.type().IsA()) { + rhs = (Symbol("mlir") + field.name() + "values").ToCcVarName(); + rhs_getter = "GetExprsRegionValues"; + rhs_kind = RhsKind::kListElement; + } else { + rhs = (Symbol("mlir") + field.name() + "value").ToCcVarName(); + rhs_getter = "GetExprRegionValue"; + rhs_kind = RhsKind::kListElement; + } + break; + } + case FIELD_KIND_STMT: { + if (field.type().IsA()) { + rhs = (Symbol("mlir") + field.name() + "block").ToCcVarName(); + rhs_getter = "GetStmtsRegionBlock"; + rhs_kind = RhsKind::kListElement; + } else { + rhs = (Symbol("mlir") + field.name() + "operation").ToCcVarName(); + rhs_getter = "GetStmtRegionOperation"; + rhs_kind = RhsKind::kFieldGetterResult; + } + break; + } + } + + auto vars = WithVars({ + {"rhs", rhs}, + {"RhsGetter", rhs_getter}, + }); + + Println( + "MALDOCA_ASSIGN_OR_RETURN" + "(auto $rhs$, $RhsGetter$(op.$mlirGetter$()));"); + + PrintFromIr(ast, action, field.type(), field.kind(), /*lhs=*/field.name(), + rhs, rhs_kind); + }; + + switch (maybe_null) { + case MaybeNull::kYes: { + Println("$cc_type$ $lhs$;"); + Println("if (!op.$mlirGetter$().empty()) {"); + { + auto indent = WithIndent(); + print_from_region(Action::kAssign); + } + Println("}"); + break; + } + case MaybeNull::kNo: { + print_from_region(Action::kDef); + } + } +} + +void IrToAstSourcePrinter::PrintNullableFromIr( + const AstDef &ast, Action action, const Type &type, MaybeNull maybe_null, + FieldKind kind, const Symbol &lhs, const std::string &rhs, + RhsKind rhs_kind) { + auto none_op = Symbol(absl::StrCat(ast.lang_name(), "ir")) + "NoneOp"; + auto vars = WithVars({ + {"type", type.CcType(maybe_null)}, + {"lhs", 
lhs.ToCcVarName()}, + {"rhs", rhs}, + {"NoneOp", none_op.ToPascalCase()}, + }); + + switch (maybe_null) { + case MaybeNull::kYes: { + switch (action) { + case Action::kDef: + Println("$type$ $lhs$;"); + ABSL_FALLTHROUGH_INTENDED; + + case Action::kAssign: { + // When a field is an mlir::Attribute, it can be null; but when a + // field is an mlir::Value, it cannot be null (MLIR verification would + // fail). Therefore, we need to define a specific NoneOp + // operation to represent null fields. + bool should_use_none_op = false; + switch (rhs_kind) { + case RhsKind::kOp: + case RhsKind::kFieldGetterResult: + break; + case RhsKind::kListElement: + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_STMT: + case FIELD_KIND_ATTR: + break; + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: + should_use_none_op = true; + } + break; + } + + if (should_use_none_op) { + Println("if (!llvm::isa<$NoneOp$>($rhs$.getDefiningOp())) {"); + } else { + Println("if ($rhs$ != nullptr) {"); + } + + { + auto indent = WithIndent(); + PrintFromIr(ast, Action::kAssign, type, kind, lhs, rhs, rhs_kind); + } + Println("}"); + break; + } + } + break; + } + case MaybeNull::kNo: + PrintFromIr(ast, action, type, kind, lhs, rhs, rhs_kind); + break; + } +} + +void IrToAstSourcePrinter::PrintFromIr(const AstDef &ast, Action action, + const Type &type, FieldKind kind, + const Symbol &lhs, + const std::string &rhs, + RhsKind rhs_kind) { + switch (type.kind()) { + case TypeKind::kBuiltin: { + const auto &builtin_type = static_cast(type); + PrintBuiltinFromIr(ast, action, builtin_type, lhs, rhs, rhs_kind); + break; + } + case TypeKind::kClass: { + const auto &class_type = static_cast(type); + PrintClassFromIr(ast, action, class_type, kind, lhs, rhs, rhs_kind); + break; + } + case TypeKind::kVariant: { + const auto &variant_type = static_cast(type); + PrintVariantFromIr(ast, action, variant_type, kind, lhs, rhs, rhs_kind); + break; + } + case TypeKind::kList: { + const auto &list_type = 
static_cast(type); + PrintListFromIr(ast, action, list_type, kind, lhs, rhs); + break; + } + // Braced like the sibling cases so `enum_type` is scoped to this label and + // a later case label cannot jump past its initialization. + case TypeKind::kEnum: { + const auto &enum_type = static_cast(type); + PrintEnumFromIr(ast, action, enum_type, lhs, rhs); + break; + } + } +} + +// Prints generated C++ that reads a builtin value out of `rhs` and either +// defines (`Action::kDef`) or assigns (`Action::kAssign`) `lhs`. For +// `RhsKind::kListElement`, first prints a checked llvm::dyn_cast to the +// attribute type and reads from the cast result instead. +void IrToAstSourcePrinter::PrintBuiltinFromIr(const AstDef &ast, Action action, + const BuiltinType &type, + const Symbol &lhs, + const std::string &rhs, + RhsKind rhs_kind) { + auto vars = WithVars({ + {"type", type.CcType()}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs}, + {"AttrName", type.CcMlirGetterType(FIELD_KIND_ATTR)}, + }); + + std::string converted_rhs; + switch (rhs_kind) { + case RhsKind::kOp: + case RhsKind::kFieldGetterResult: + converted_rhs = rhs; + break; + case RhsKind::kListElement: { + auto cast = UnIndentedSource(R"( + auto $lhs$_attr = llvm::dyn_cast<$AttrName$>($rhs$); + if ($lhs$_attr == nullptr) { + return absl::InvalidArgumentError("Invalid attribute."); + } + )"); + Println(cast); + + converted_rhs = absl::StrCat(lhs.ToCcVarName(), "_attr"); + break; + } + } + auto converted_rhs_var = WithVars({ + {"rhs", converted_rhs}, + }); + + switch (action) { + case Action::kDef: + Print("$type$ "); + ABSL_FALLTHROUGH_INTENDED; + case Action::kAssign: + Print("$lhs$ = "); + break; + } + + switch (type.builtin_kind()) { + case BuiltinTypeKind::kBool: { + Print("$rhs$.getValue()"); + break; + } + case BuiltinTypeKind::kInt64: { + Print("$rhs$.getValue().getInt()"); + break; + } + case BuiltinTypeKind::kString: { + Print("$rhs$.str()"); + break; + } + case BuiltinTypeKind::kDouble: { + Print("$rhs$.getValueAsDouble()"); + break; + } + } + + Println(";"); +} + +void IrToAstSourcePrinter::PrintClassFromIr(const AstDef &ast, Action action, + const ClassType &type, + FieldKind kind, const Symbol &lhs, + const std::string &rhs, + RhsKind rhs_kind) { + auto node_it = ast.nodes().find(type.name().ToPascalCase()); + CHECK(node_it != ast.nodes().end()); + auto op_name = node_it->second->ir_op_name(ast.lang_name(), kind); + + auto 
vars = WithVars({ + {"ClassName", type.name().ToPascalCase()}, + {"OpName", + op_name.has_value() ? op_name->ToPascalCase() : "mlir::Operation*"}, + {"lhs", lhs.ToCcVarName()}, + }); + + // Sometimes `rhs` is a mlir::Value, so we first need to call getDefiningOp(). + std::string converted_rhs; + switch (rhs_kind) { + case RhsKind::kOp: + converted_rhs = rhs; + break; + case RhsKind::kFieldGetterResult: + case RhsKind::kListElement: { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_ATTR: + case FIELD_KIND_STMT: + converted_rhs = rhs; + break; + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: + converted_rhs = absl::StrCat(rhs, ".getDefiningOp()"); + break; + } + break; + } + } + auto converted_rhs_var = WithVars({ + {"rhs", converted_rhs}, + }); + + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_ATTR: { + switch (rhs_kind) { + case RhsKind::kOp: + case RhsKind::kFieldGetterResult: + // $rhs$ is a Attr now. + break; + case RhsKind::kListElement: { + auto cast = UnIndentedSource(R"( + auto $lhs$_attr = llvm::dyn_cast<$OpName$>($rhs$); + if ($lhs$_attr == nullptr) { + return absl::InvalidArgumentError("Invalid attribute."); + } + )"); + Println(cast); + + converted_rhs = absl::StrCat(lhs.ToCcVarName(), "_attr"); + + break; + } + } + break; + } + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: { + switch (rhs_kind) { + case RhsKind::kOp: + // $rhs$ is a Op now. 
+ break; + case RhsKind::kFieldGetterResult: + case RhsKind::kListElement: { + auto cast = UnIndentedSource(R"cc( + auto $lhs$_op = llvm::dyn_cast<$OpName$>($rhs$); + if ($lhs$_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected $OpName$, got ", + $rhs$->getName().getStringRef().str(), ".")); + } + )cc"); + Println(cast); + + converted_rhs = absl::StrCat(lhs.ToCcVarName(), "_op"); + + break; + } + } + break; + } + + case FIELD_KIND_STMT: { + switch (rhs_kind) { + case RhsKind::kOp: + break; + case RhsKind::kFieldGetterResult: { + // $rhs$ is an mlir::Operation* now. + auto cast = UnIndentedSource(R"cc( + auto $lhs$_op = llvm::dyn_cast<$OpName$>($rhs$); + if ($lhs$_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected $OpName$, got ", + $rhs$->getName().getStringRef().str(), ".")); + } + )cc"); + Println(cast); + + converted_rhs = absl::StrCat(lhs.ToCcVarName(), "_op"); + + break; + } + case RhsKind::kListElement: { + auto cast = UnIndentedSource(R"( + auto $lhs$_op = llvm::dyn_cast<$OpName$>($rhs$); + if ($lhs$_op == nullptr) { + continue; + } + )"); + Println(cast); + + converted_rhs = absl::StrCat(lhs.ToCcVarName(), "_op"); + + break; + } + } + break; + } + } + + auto further_converted_rhs_var = WithVars({ + {"rhs", converted_rhs}, + }); + + Print("MALDOCA_ASSIGN_OR_RETURN("); + + switch (action) { + case Action::kDef: { + auto vars = WithVars({ + {"type", type.CcType()}, + }); + Print("$type$ "); + ABSL_FALLTHROUGH_INTENDED; + } + case Action::kAssign: + Print("$lhs$, "); + break; + } + + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_ATTR: + Println("Visit$ClassName$Attr($rhs$));"); + break; + case FIELD_KIND_RVAL: + case FIELD_KIND_STMT: { + Println("Visit$ClassName$($rhs$));"); + break; + } + case FIELD_KIND_LVAL: { + Println("Visit$ClassName$Ref($rhs$));"); + break; + } + } +} + +void IrToAstSourcePrinter::PrintEnumFromIr(const AstDef &ast, Action action, + const EnumType &type, + const 
Symbol &lhs, + const std::string &rhs) { + auto enum_name = Symbol(ast.lang_name()) + type.name(); + + absl::btree_map vars = { + {"Type", enum_name.ToPascalCase()}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", rhs}, + }; + + switch (action) { + case Action::kDef: + Println(vars, + "MALDOCA_ASSIGN_OR_RETURN" + "($Type$ $lhs$, StringTo$Type$($rhs$.str()));"); + break; + case Action::kAssign: + // `$lhs$` must be a printer variable; a bare `lhs` would be emitted + // verbatim into the generated code instead of the target variable name. + Println(vars, + "MALDOCA_ASSIGN_OR_RETURN($lhs$, StringTo$Type$($rhs$.str()));"); + break; + } +} + +// Prints one IfStmtPrinter case per alternative of the variant `type`. Each +// case's condition binds `$mlir_lhs$` (a dyn_cast of `rhs`, or `rhs` itself +// when the class has no IR op name) and its body assigns the converted value +// to `lhs`. For `Action::kDef`, `lhs` is declared first. A trailing `else` +// handles values that match no alternative. +void IrToAstSourcePrinter::PrintVariantFromIr(const AstDef &ast, Action action, + const VariantType &type, + FieldKind kind, const Symbol &lhs, + const std::string &rhs, + RhsKind rhs_kind) { + auto mlir_lhs = (Symbol("mlir") + lhs).ToCcVarName(); + auto vars = WithVars({ + {"type", type.CcType()}, + {"lhs", lhs.ToCcVarName()}, + {"mlir_lhs", mlir_lhs}, + }); + + switch (action) { + case Action::kDef: + Println("$type$ $lhs$;"); + break; + case Action::kAssign: + break; + } + + std::string converted_rhs; + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_ATTR: + case FIELD_KIND_STMT: + converted_rhs = rhs; + break; + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: + converted_rhs = absl::StrCat(rhs, ".getDefiningOp()"); + break; + } + auto rhs_var = WithVars({ + {"rhs", converted_rhs}, + }); + + IfStmtPrinter if_stmt(this); + for (const auto &scalar_type : type.types()) { + if_stmt.PrintCase({ + [&] { + if (scalar_type->IsA()) { + const auto &class_type = + static_cast(*scalar_type); + + auto node_it = ast.nodes().find(class_type.name().ToPascalCase()); + CHECK(node_it != ast.nodes().end()); + auto op_name = node_it->second->ir_op_name(ast.lang_name(), kind); + if (!op_name.has_value()) { + // There is no base op or attribute. Fallback to mlir::Attribute + // or mlir::Operation*. + + Print("auto $mlir_lhs$ = $rhs$"); + // Stop here: printing the dyn_cast below as well would emit two + // declarations of `$mlir_lhs$` in the same condition. + return; + } + auto vars = WithVars({ + {"Op", op_name.has_value() ? 
op_name->ToPascalCase() + : "mlir::Operation*"}, + }); + + Print("auto $mlir_lhs$ = llvm::dyn_cast<$Op$>($rhs$)"); + } else if (scalar_type->IsA()) { + const auto &builtin_type = + static_cast(*scalar_type); + + auto attr_name = builtin_type.CcMlirGetterType(kind); + auto vars = WithVars({ + {"Attr", attr_name}, + }); + + Print("auto $mlir_lhs$ = llvm::dyn_cast<$Attr$>($rhs$)"); + } + }, + [&] { + // Body + PrintFromIr(ast, Action::kAssign, *scalar_type, kind, lhs, mlir_lhs, + RhsKind::kOp); + }, + }); + } + + switch (kind) { + case FIELD_KIND_STMT: { + if (rhs_kind == RhsKind::kListElement) { + const auto handle_invalid_type = UnIndentedSource(R"( + else { + continue; + } + )"); + Println(handle_invalid_type); + break; + } + + ABSL_FALLTHROUGH_INTENDED; + } + case FIELD_KIND_UNSPECIFIED: + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: + case FIELD_KIND_ATTR: { + const auto handle_invalid_type = UnIndentedSource(R"( + else { + return absl::InvalidArgumentError("$rhs$ has invalid type."); + } + )"); + Println(handle_invalid_type); + break; + } + } +} + +void IrToAstSourcePrinter::PrintListFromIr(const AstDef &ast, Action action, + const ListType &type, FieldKind kind, + const Symbol &lhs, + const std::string &rhs) { + const Symbol lhs_element = lhs + "element"; + const Symbol mlir_lhs_element = Symbol("mlir") + lhs + "element_unchecked"; + + // Even if we are asked to assign to `lhs`, since the type of `lhs` is not + // exactly type.CcType(), we need to define a new variable and assign it + // to `lhs` at the end. 
+ const Symbol lhs_defined = [&] { + switch (action) { + case Action::kDef: + return lhs; + case Action::kAssign: + return lhs + "value"; + } + }(); + + std::string converted_rhs = rhs; + std::string rhs_element_type; + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "FieldKind unspecified."; + case FIELD_KIND_ATTR: { + // rhs: mlir::ArrayAttr; + // rhs.getValue(): llvm::ArrayRef + converted_rhs = absl::StrCat(rhs, ".getValue()"); + rhs_element_type = "mlir::Attribute"; + break; + } + case FIELD_KIND_RVAL: + case FIELD_KIND_LVAL: { + // rhs: mlir::OperandRange + converted_rhs = rhs; + rhs_element_type = "mlir::Value"; + break; + } + case FIELD_KIND_STMT: { + // rhs: mlir::Block* + converted_rhs = absl::StrCat("*", rhs); + rhs_element_type = "mlir::Operation&"; + break; + } + } + + auto vars = WithVars({ + {"cc_type", type.CcType()}, + {"lhs", lhs.ToCcVarName()}, + {"rhs", converted_rhs}, + {"rhs_element_type", rhs_element_type}, + {"lhs_defined", lhs_defined.ToCcVarName()}, + {"lhs_element", lhs_element.ToCcVarName()}, + {"mlir_lhs_element", mlir_lhs_element.ToCcVarName()}, + }); + + Println("$cc_type$ $lhs_defined$;"); + Println("for ($rhs_element_type$ $mlir_lhs_element$ : $rhs$) {"); + { + auto indent = WithIndent(); + PrintNullableFromIr(ast, Action::kDef, type.element_type(), + type.element_maybe_null(), kind, lhs_element, + mlir_lhs_element.ToCcVarName(), RhsKind::kListElement); + Println("$lhs_defined$.push_back(std::move($lhs_element$));"); + } + Println("}"); + + switch (action) { + case Action::kDef: + break; + case Action::kAssign: + Println("$lhs$ = std::move($lhs_defined$);"); + break; + } +} + +std::string PrintIrToAstSource(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path, + absl::string_view ir_path) { + std::string str; + { + google::protobuf::io::StringOutputStream os(&str); + IrToAstSourcePrinter printer(&os); + printer.PrintAst(ast, cc_namespace, ast_path, ir_path); + } + + return str; +} + +} // 
namespace maldoca diff --git a/maldoca/astgen/ast_gen.h b/maldoca/astgen/ast_gen.h new file mode 100644 index 0000000..7c8e7ed --- /dev/null +++ b/maldoca/astgen/ast_gen.h @@ -0,0 +1,956 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_AST_GEN_H_ +#define MALDOCA_ASTGEN_AST_GEN_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "maldoca/astgen/ast_def.h" +#include "maldoca/astgen/symbol.h" +#include "maldoca/astgen/type.h" +#include "google/protobuf/io/printer.h" +#include "google/protobuf/io/zero_copy_stream.h" + +namespace maldoca { + +class AstGenPrinterBase : public google::protobuf::io::Printer { + public: + explicit AstGenPrinterBase(google::protobuf::io::ZeroCopyOutputStream *os) + : google::protobuf::io::Printer(os, /*variable_delimiter=*/'$') {} + + template + void Println(Args &&...args) { + Print(std::forward(args)...); + Print("\n"); + } + + void Println() { Println(""); } +}; + +// Printer of the TypeScript interface definition for the AST. +// +// Format: +// +// interface ObjectMember <: Node { +// key: Expression; +// computed: boolean; +// decorators?: [ Decorator ]; +// } +// +// Inherits publicly, like CcPrinterBase and the printers derived from it; a +// bare `class X : Base` would make the base private. +class TsInterfacePrinter : public AstGenPrinterBase { + public: + explicit TsInterfacePrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : AstGenPrinterBase(os) {} + + // Prints the "ast_ts_interface.generated" file. + // + // See test cases in test/ for examples. 
+ void PrintAst(const AstDef &ast); + + // Prints an enum definition. + // + // See test cases in test/ for examples. + void PrintEnum(const EnumDef &enum_def, absl::string_view lang_name); + + // Prints the class declaration for a node. + // + // See test cases in test/ for examples. + void PrintNode(const NodeDef &node); + + // Prints the definition of a field. + // + // Format: + // : + // ?: + // + // - fieldName: Printed as camelCase. + // - js_type: See `Type::JsType()`. + // + // Example: + // right: Expression + // param?: Pattern + void PrintFieldDef(const FieldDef &field); +}; + +std::string PrintTsInterface(const AstDef &ast); + +// Common functions for printing C++ code. +class CcPrinterBase : public AstGenPrinterBase { + public: + explicit CcPrinterBase(google::protobuf::io::ZeroCopyOutputStream *os) + : AstGenPrinterBase(os) {} + + // Print Apache license comment. + void PrintLicense(); + + // Example: + // + // Input: + // cc_namespace == "maldoca::astgen" + // + // Output: + // ``` + // namespace maldoca { + // namespace astgen { + // ``` + void PrintEnterNamespace(absl::string_view cc_namespace); + + // Example: + // + // Input: + // cc_namespace == "maldoca::astgen" + // + // Output: + // ``` + // } // namespace astgen + // } // namespace maldoca + // ``` + void PrintExitNamespace(absl::string_view cc_namespace); + + // Example: + // + // Input: + // header_path == "maldoca/astgen/test/lambda/ast.h" + // + // Output: + // ``` + // #ifndef MALDOCA_ASTGEN_TEST_LAMBDA_AST_H_ + // #define MALDOCA_ASTGEN_TEST_LAMBDA_AST_H_ + // ``` + void PrintEnterHeaderGuard(absl::string_view header_path); + + // Example: + // + // Input: + // header_path == "maldoca/astgen/test/lambda/ast.h" + // + // Output: + // ``` + // #endif // MALDOCA_ASTGEN_TEST_LAMBDA_AST_H_ + // ``` + void PrintExitHeaderGuard(absl::string_view header_path); + + // Example: + // + // Input: + // header_path == "maldoca/astgen/test/lambda/ast.h" + // + // Output: + // ``` + // #include 
"maldoca/astgen/test/lambda/ast.h" + // ``` + void PrintIncludeHeader(absl::string_view header_path); + + // Prints headers in alphabetical order by sorting a copy of the header paths. + void PrintIncludeHeaders(std::vector header_paths); + + // Example: + // + // Input: + // title == "BinaryExpression" + // + // Output: + // ``` + // // ======================================================================== + // // BinaryExpression + // // ======================================================================== + // ``` + void PrintTitle(absl::string_view title); + + // Output: + // // ======================================================================== + // // STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. + // // ======================================================================== + void PrintCodeGenerationWarning(); + + // Some convenient wrappers for printing C++ types. + std::string CcType(const FieldDef &field) const { + return field.type().CcType(field.optionalness()); + } + + std::string CcMutableGetterType(const FieldDef &field) const { + return field.type().CcMutableGetterType(field.optionalness()); + } + + std::string CcConstGetterType(const FieldDef &field) const { + return field.type().CcConstGetterType(field.optionalness()); + } +}; + +// Printer of the C++ header for the AST. +class AstHeaderPrinter : public CcPrinterBase { + public: + explicit AstHeaderPrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : CcPrinterBase(os) {} + + // Prints the "ast.generated.h" header file. + // + // - cc_namespace: The C++ namespace for all the AST node classes. + // Example: "maldoca::astgen". + // + // - ast_path: The directory for the AST code. + // "ast.generated.h" is in that directory. + // This is used to generate the header guard. + // + // See test cases in test/ for examples. 
+ void PrintAst(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + + // Prints the enum definition and the prototypes of string conversion + // functions. + // + // Example: + // enum UnaryOperator { + // kMinus, + // ... + // }; + // + // absl::string_view UnaryOperatorToString(UnaryOperator unary_operator); + // absl::StatusOr StringToUnaryOperator(absl::string_view s); + void PrintEnum(const EnumDef &enum_def, absl::string_view lang_name); + + // Prints the class declaration for a node. + // + // See test cases in test/ for examples. + void PrintNode(const NodeDef &node, absl::string_view lang_name); + + // Prints the constructor of a node class. + // + // Example: + // explicit Variable(std::string identifier) + // : Expression(), identifier_(std::move(identifier)) {} + void PrintConstructor(const NodeDef &node, absl::string_view lang_name); + + // Prints the getter and setter declarations for a field. + // + // Format: + // (); + // () const; + // void set_( ); + // + // - cc_mutable_getter_type: See `Type::CcMutableGetterType()`. + // - cc_const_getter_type: See `Type::CcConstGetterType()`. + // - cc_type: See `Type::CcType()`. + // + // Example: + // Expression* right(); + // const Expression* right() const; + // void set_right(std::unique_ptr right); + void PrintGetterSetterDeclarations(const FieldDef &field, + absl::string_view lang_name); + + // Prints a member variable declaration. + // + // Format: + // _; + // + // - cc_type: The C++ value type. See `Type::CcType()`. + // - field_name_: We print the name in snake_case and add a '_'. 
+ // + // Example: + // std::unique_ptr right_; + void PrintMemberVariable(const FieldDef &field, absl::string_view lang_name); + + // Format: + // static absl::StatusOr<> + // GetFromJson(const nlohmann::json& json); + // + // Example: + // static absl::StatusOr> + // GetRightFromJson(const nlohmann::json& json); + void PrintGetFromJson(const FieldDef &field, absl::string_view lang_name); +}; + +// Prints the "ast.generated.h" header file. +// +// - cc_namespace: The C++ namespace for all the AST node classes. +// Example: "maldoca::astgen". +// +// - ast_path: The directory for the AST code. +// "ast.generated.h" is in that directory. +// This is used to generate the header guard. +std::string PrintAstHeader(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + +// Printer of the C++ source for the AST. +class AstSourcePrinter : public CcPrinterBase { + public: + explicit AstSourcePrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : CcPrinterBase(os) {} + + // Prints the "ast.generated.cc" file, which includes the definitions of + // getters and setters of all the AST node classes. + // + // - cc_namespace: A namespace separated by "::". + // This is used to print C++ namespaces. + // + // - ast_path: The directory for the AST code. + // "ast.generated.h" is in that directory. + // This is used to print the #include. + void PrintAst(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + + private: + void PrintConstructor(const NodeDef &node, absl::string_view lang_name); + + // Prints the string conversion functions. + // + // Example: + // + // absl::string_view UnaryOperatorToString(UnaryOperator unary_operator) { + // ... + // } + // + // absl::StatusOr StringToUnaryOperator(absl::string_view s) { + // ... + // } + void PrintEnum(const EnumDef &enum_def, absl::string_view lang_name); + + // Prints the getters and setters of one AST node class. 
+ void PrintNode(const NodeDef &node, absl::string_view lang_name); + + // Prints the C++ code that returns a value that's compatible with the types + // `type.CcMutableGetterType()` and `type.CcConstGetterType()`. + // + // `cc_expr` is an lvalue expression of the type `type.CcType()`. + void PrintGetterBody(const std::string &cc_expr, const Type &type); + + // Prints the C++ code that returns a value that's compatible with the types + // `type.CcMutableGetterType(is_optional)` and + // `type.CcConstGetterType(is_optional)`. + // + // `field_name` identifies the field to read. NOTE(review): this overload has + // no `cc_expr` parameter; presumably the expression is derived from + // `field_name` - confirm against the definition. + void PrintGetterBody(const Symbol &field_name, const Type &type, + bool is_optional); + + // Prints the C++ code that sets one field. + // + // `field_name` is an lvalue expression that has the type + // `type.CcType(is_optional)`. We need to set the field `field_name_`. + void PrintSetterBody(const Symbol &field_name, const Type &type, + bool is_optional); +}; + +// NOTE(review): presumably returns the text that AstSourcePrinter::PrintAst() +// prints; the definition is not visible here. +std::string PrintAstSource(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + +enum class DefOrAssign { + // Define a variable of the exact same type. + kDef, + + // Assign to an existing variable of a compatible type. + kAssign, +}; + +// Printer of the C++ Serialize() function for the AST. +class AstSerializePrinter : public CcPrinterBase { + public: + explicit AstSerializePrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : CcPrinterBase(os) {} + + void PrintAst(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + + private: + // Print*Serialize() + // + // Prints either: + // - An assignment " = ConvertSerialize();", or + // - A variable definition "nlohmann::json = ConvertSerialize();" + // + // - lhs: If printing an assignment, an lvalue expression of type + // nlohmann::json; if printing a variable definition, the name of that + // variable. + // - rhs: An expression of type `type.CcType()`. 
+ void PrintBuiltinSerialize(const BuiltinType &type, + const std::string &lhs, + const std::string &rhs); + + void PrintEnumSerialize(const EnumType &type, + const std::string &lhs, + const std::string &rhs, + absl::string_view lang_name); + + void PrintClassSerialize(const ClassType &type, + const std::string &lhs, + const std::string &rhs); + + void PrintVariantSerialize(const VariantType &variant_type, + const std::string &lhs, + const std::string &rhs, + absl::string_view lang_name); + + void PrintListSerialize(const ListType &list_type, + const std::string &lhs, + const std::string &rhs, + absl::string_view lang_name); + + void PrintSerialize(const Type &type, const std::string &lhs, + const std::string &rhs, absl::string_view lang_name); + + void PrintNullableToJson(const Type &type, MaybeNull maybe_null, + const std::string &lhs, const std::string &rhs, + absl::string_view lang_name); + + void PrintSerializeFieldsFunction(const NodeDef &node, + absl::string_view lang_name); + + void PrintSerializeFunction(const NodeDef &node, + absl::string_view lang_name); + + void PrintSerializeFunctionOverload(const NodeDef &node, + absl::string_view lang_name); +}; + +// Prints the "ast_to_json.generated.cc" source file. +// +// - cc_namespace: The C++ namespace for all the AST node classes. +// Example: "maldoca::astgen". +// +// - ast_path: The directory for the AST code. +// "ast.generated.h" is in that directory. +// This is used to print the #include. 
+std::string PrintAstToJson(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + +class AstFromJsonPrinter : public CcPrinterBase { + public: + explicit AstFromJsonPrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : CcPrinterBase(os) {} + + void PrintAst(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + + void PrintTypeChecker(const NodeDef &node); + + void PrintBuiltinJsonTypeCheck(const BuiltinType &type, const Symbol &rhs); + + void PrintClassJsonTypeCheck(const ClassType &class_type, const Symbol &rhs); + + // =========================================================================== + // Print*FromJson() + // =========================================================================== + // + // Prints the code that converts a nlohmann::json variable `rhs` to a C++ + // value of the type `type.CcType()`. + // + // `rhs` is guaranteed to be non-null. + // However, the type of `rhs` is unknown and might not match `type`. + // We need to print type-checking code before the actual conversion. + // + // Action + // ====== + // + // - Def: + // + // Defines a variable with name `lhs` and type `type.CcType()`. + // Stores the conversion result. + // + // - Assign: + // + // Assigns the conversion result to `lhs`, which is an existing variable. + // The type of `lhs` is not `type.CcType()`, but you can assign a value of + // `type.CcType()` to it. + // + // This is fine with scalar types. + // Consider the following example: + // nlohmann::json rhs = ...; + // std::optional lhs = ...; + // + // // Generated code starts here: + // // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // // ... Checks that `rhs` is a JSON array ... + // lhs = rhs.get(); // This assignment is okay. + // + // However, this is troublesome with list type. + // Consider the following example: + // nlohmann::json rhs = ...; + // std::optional> lhs = ...; + // + // // Generated code starts here: + // // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // // ... 
Checks that `rhs` is a JSON array ... + // for (const nlohmann::json& rhs_element : rhs) { + // std::string lhs_element = ...; // Convert from rhs_element. + // // Oh awkward - we can't do "lhs.push_back(...)". + // } + // + // Therefore, for lists, we need to create a new variable `lhs_value` and + // assign it to `lhs` in the end. Like this: + // nlohmann::json rhs = ...; + // std::optional> lhs = ...; + // + // // Generated code starts here: + // // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // // ... Checks that `rhs` is a JSON array ... + // std::vector lhs_value; + // for (const nlohmann::json& rhs_element : rhs) { + // std::string lhs_element = ...; // Convert from rhs_element. + // lhs_value.push_back(std::move(lhs_element)); + // } + // lhs = std::move(lhs_value); + // + // - Return: + // + // Returns the conversion result. + enum class Action { + kDef, + kAssign, + kReturn, + }; + + // CheckJsonType - Whether to print the JSON type check code. + // + // This is a special case for BuiltinType. + // In all other cases we always print the type check. 
+ enum class CheckJsonType { + kYes, + kNo, + }; + + void PrintBuiltinFromJson(Action action, CheckJsonType check_json_type, + const BuiltinType &builtin_type, const Symbol &lhs, + const Symbol &rhs); + + void PrintEnumFromJson(Action action, const EnumType &enum_type, + const Symbol &lhs, const Symbol &rhs, + absl::string_view lang_name); + + void PrintClassFromJson(Action action, const ClassType &class_type, + const Symbol &lhs, const Symbol &rhs, + absl::string_view lang_name); + + void PrintVariantFromJson(Action action, const VariantType &variant_type, + const Symbol &lhs, const Symbol &rhs, + absl::string_view lang_name); + + void PrintListFromJson(Action action, const ListType &list_type, + const Symbol &lhs, const Symbol &rhs, + absl::string_view lang_name); + + void PrintFromJson(Action action, const Type &type, const Symbol &lhs, + const Symbol &rhs, absl::string_view lang_name); + + // Same as PrintFromJson, but: + // + // - The nullness of `rhs` is not guaranteed. We need to print the nullness + // check code. + // + // - Instead of `type.CcType()`, the output type is `type.CcType(maybe_null)`. 
+ void PrintNullableFromJson(Action action, const Type &type, + MaybeNull maybe_null, const Symbol &lhs, + const Symbol &rhs, absl::string_view lang_name); + + void PrintGetFieldFunction(const std::string &node_name, + const FieldDef &field, + absl::string_view lang_name); + + void PrintFromJsonFunction(const NodeDef &node, absl::string_view lang_name); +}; + +std::string PrintAstFromJson(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path); + +class IrTableGenPrinter : public CcPrinterBase { + public: + explicit IrTableGenPrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : CcPrinterBase(os) {} + + void PrintAst(const AstDef &ast, absl::string_view ir_path); + + // Example: + // + // def JsirWithStatementOp : Jsir_Op< + // "with_statement", [ + // JsirStatementOpInterfaceTraits + // ]> { + // let arguments = (ins + // AnyType: $object + // ); + // + // let regions = (region + // AnyRegion: $body + // ); + // } + void PrintNode(const AstDef &ast, const NodeDef &node, FieldKind kind); + + // Prints an argument for an op in MLIR ODS. + // + // Format: + // + // : $ + // + // See Type::TdType() for what the MLIR ODS type is for each Type. + // + // Example: + // + // AnyType: $object + void PrintArgument(const AstDef &ast, const NodeDef &node, + const FieldDef &field); + + // Prints a region in an op in MLIR ODS. + // + // Format: + // + // AnyRegion: $ + // + // Example: + // + // AnyRegion: $body + void PrintRegion(const AstDef &ast, const NodeDef &node, + const FieldDef &field); +}; + +// Prints the "ir_ops.generated.td" TableGen file. +// +// - ir_path: The directory for the IR code. +// +// The following files are in that directory: +// - "ir_dialect.td" +// - "ir_ops.generated.td" +// - "interfaces.td" +// +// This is used to print the includes and header guards. 
+std::string PrintIrTableGen(const AstDef &ast, absl::string_view ir_path); + +class AstToIrSourcePrinter : public CcPrinterBase { + public: + explicit AstToIrSourcePrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : CcPrinterBase(os) {} + + // Action: What to do with the converted IR value/attribute. + // + // - Def: Define a variable. + // - Assign: Assign the value/attribute to an existing variable. + // - Create: Just create the value/attribute and ignore it. + // + // See comments for Print*ToIr for more details. + enum class Action { + kDef, + kAssign, + kCreate, + }; + + // Whether a C++ expression refers to a "reference" or a "value". + // + // Consider the following AST node: + // class CallExpression : ... { + // public: + // const Expression *func() const; + // const std::vector> *args() const; + // }; + // + // - The type of func() is "const Expression *". + // We consider this a "reference". + // + // - The type of args()[0] is "std::unique_ptr &". + // We consider this a "value". + // + // However, in the ASTGen type system, we refer them both as + // ClassType{"Expression"}. Therefore, we need this additional enum to make + // the distinction. + // + // If a function takes a "reference" but we have a "value", we need to call + // ".get()" to turn it into a "reference". + enum RefOrVal { + kRef, + kVal, + }; + + void PrintAst(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path, absl::string_view ir_path); + + // Prints the Visit() function. + void PrintNonLeafNode(const AstDef &ast, const NodeDef &node, FieldKind kind); + + void PrintLeafNode(const AstDef &ast, const NodeDef &node, FieldKind kind); + + // =========================================================================== + // Print*ToIr + // =========================================================================== + // + // Prints the conversion of a C++ expression that represents a field from the + // AST to the corresponding MLIR value/attribute. 
The result is later used to + // build MLIR ops. + // - rhs: The original C++ expression that represents a field from the AST. + // + // - lhs: The name of the variable to assign to or create, after the + // conversion. + // + // - action: + // - kDef: + // mlir::Value = Convert(); + // - kAssign: + // = Convert(); + // - kCreate: + // Convert(); + // + // - type: The type of the AST field. + // + // - ref_or_val: See comments for RefOrVal. + // + // - kind: Kind of the field. See comments for FieldKind. + // If kind == FIELD_KIND_LVAL, then we need to append "Ref" to the op name. + void PrintBuiltinToIr(const AstDef &ast, Action action, + const BuiltinType &type, const Symbol &lhs, + const std::string &rhs); + + void PrintClassToIr(const AstDef &ast, Action action, const ClassType &type, + FieldKind kind, const Symbol &lhs, + const std::string &rhs); + + void PrintClassToIr(const AstDef &ast, Action action, const ClassType &type, + RefOrVal ref_or_val, FieldKind kind, const Symbol &lhs, + const std::string &rhs); + + void PrintEnumToIr(const AstDef &ast, Action action, const EnumType &type, + const Symbol &lhs, const std::string &rhs); + + void PrintVariantToIr(const AstDef &ast, Action action, + const VariantType &type, RefOrVal ref_or_val, + FieldKind kind, const Symbol &lhs, + const std::string &rhs); + + void PrintListToIr(const AstDef &ast, Action action, const ListType &type, + FieldKind kind, const Symbol &lhs, const std::string &rhs); + + void PrintToIr(const AstDef &ast, Action action, const Type &type, + RefOrVal ref_or_val, FieldKind kind, const Symbol &lhs, + const std::string &rhs); + + void PrintNullableToIr(const AstDef &ast, Action action, const Type &type, + MaybeNull maybe_null, RefOrVal ref_or_val, + FieldKind kind, const Symbol &lhs, + const std::string &rhs); + + // Prints the code that converts an AST field to an MLIR value/attribute and + // stores the result in a new variable. 
+ // + // Format: + // + // mlir_ = Visit(node->()); + // + // Example: + // + // mlir::Value mlir_object = VisitExpression(node->object()); + void PrintField(const AstDef &ast, const NodeDef &node, + const FieldDef &field); + + // Prints the code that converts an AST field to a region. The region has been + // created and the code just populates blocks and ops in it. + // + // Format: + // + // mlir::Region &mlir__region = op.(); + // AppendNewBlockAndPopulate(mlir__region, [&] { + // foo() into elements in the region.> + // }); + // + // Example: + // + // mlir::Region &mlir_body_region = op.body(); + // AppendNewBlockAndPopulate(mlir_body_region, [&] { + // for (const auto &element : *node->body()) { + // VisitStatement(element.get()); + // } + // }); + void PrintRegion(const AstDef &ast, const NodeDef &node, + const FieldDef &field); +}; + +// Prints the "ast_to_ir.generated.cc" file. +// +// - cc_namespace: The namespace where all IR op classes live. +// +// - ast_path: The directory for the AST code. +// +// "ast.generated.h" is in that directory. +// +// This is used to print the #includes. +// +// - ir_path: The directory for the IR code. +// +// The following files are in that directory: +// - "ir_dialect.td" +// - "ir_ops.generated.td" +// - "interfaces.td" +// - "conversion/ast_to_ir.h" +// - "conversion/ast_to_ir.generated.cc" +// +// This is used to print the #includes and header guards. +std::string PrintAstToIrSource(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path, + absl::string_view ir_path); + +class IrToAstSourcePrinter : public CcPrinterBase { + public: + explicit IrToAstSourcePrinter(google::protobuf::io::ZeroCopyOutputStream *os) + : CcPrinterBase(os) {} + + // Action: What to do with the converted AST field. + // + // - Def: Define a variable. + // - Assign: Assign the value to an existing variable. + // + // See comments for Print*FromIr for more details. 
+ enum class Action { + kDef, + kAssign, + }; + + enum class RhsKind { + // rhs could be: + // - mlir::Value + // - mlir::ValueRange + // - mlir::Block & + // - specific attribute. + // Requesting type checking and error returning on type mismatch. + kFieldGetterResult, + + // rhs could be: + // - mlir::Operation* (for FIELD_KIND_STMT) + // - mlir::Value (for FIELD_KIND_LVAL, FIELD_KIND_RVAL) + // - mlir::Attribute + // Requesting type checking and error returning on type mismatch. + kListElement, + + // rhs is specific Op. + // llvm::cast(op.field_name().getDefiningOp()) + kOp, + }; + + void PrintAst(const AstDef &ast, absl::string_view cc_namespace, + absl::string_view ast_path, absl::string_view ir_path); + + // Prints the Visit() function. + void PrintNonLeafNode(const AstDef &ast, const NodeDef &node, FieldKind kind); + + void PrintLeafNode(const AstDef &ast, const NodeDef &node, FieldKind kind); + + void PrintField(const AstDef &ast, const NodeDef &node, + const FieldDef &field); + + // Prints the code that, given an operation `op`: + // (1) Fetches the region `op.getFieldName()`; + // (2) Extracts MLIR value/operation(s) from the region, so that we get: + // - mlir::Value + // - mlir::Operation * + // - mlir::ValueRange + // - mlir::Block * + // (3) Converts the above to a field for an AST node. 
+ // + // Example output: + // ``` + // MALDOCA_ASSIGN_OR_RETURN(auto mlir_field_name_value, + // GetExprRegionValue(op.getFieldName())); + // auto field_name_op = + // llvm::dyn_cast(mlir_field_name_value.getDefiningOp()); + // if (field_name_op == nullptr) { + // return absl::InvalidArgumentError("Invalid op."); + // } + // MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr expr, + // VisitSome(field_name_op)); + // ``` + void PrintRegion(const AstDef &ast, const NodeDef &node, + const FieldDef &field); + + // =========================================================================== + // Print*FromIr + // =========================================================================== + // + // Prints the conversion of a C++ expression that represents MLIR + // value/attribute to the corresponding field in the AST. The result is later + // used to build AST nodes. + // + // - rhs: The original C++ expression that represents an MLIR value/attribute. + // + // - lhs: The name of the variable to assign to or create, after the + // conversion. + // + // - action: + // - kDef: + // = Convert(); + // - kAssign: + // = Convert(); + // + // - type: The type of the AST field. + // + // - rhs_kind: See comments for RhsKind. + // + // - kind: Kind of the field. See comments for FieldKind. + // If kind == FIELD_KIND_LVAL, then we need to append "Ref" to the op name. 
+ void PrintNullableFromIr(const AstDef &ast, Action action, const Type &type, + MaybeNull maybe_null, FieldKind kind, + const Symbol &lhs, const std::string &rhs, + RhsKind rhs_kind); + + void PrintFromIr(const AstDef &ast, Action action, const Type &type, + FieldKind kind, const Symbol &lhs, const std::string &rhs, + RhsKind rhs_kind); + + void PrintBuiltinFromIr(const AstDef &ast, Action action, + const BuiltinType &type, const Symbol &lhs, + const std::string &rhs, RhsKind rhs_kind); + + void PrintClassFromIr(const AstDef &ast, Action action, const ClassType &type, + FieldKind kind, const Symbol &lhs, + const std::string &rhs, RhsKind rhs_kind); + + void PrintEnumFromIr(const AstDef &ast, Action action, const EnumType &type, + const Symbol &lhs, const std::string &rhs); + + void PrintVariantFromIr(const AstDef &ast, Action action, + const VariantType &type, FieldKind kind, + const Symbol &lhs, const std::string &rhs, + RhsKind rhs_kind); + + void PrintListFromIr(const AstDef &ast, Action action, const ListType &type, + FieldKind kind, const Symbol &lhs, + const std::string &rhs); + + private: + google::protobuf::io::ZeroCopyOutputStream *os_; +}; + +// Prints the "ir_to_ast.generated.cc" file. +// +// - cc_namespace: The namespace where all IR op classes live. +// +// - ast_path: The directory for the AST code. +// +// "ast.generated.h" is in that directory. +// +// This is used to print the #includes. +// +// - ir_path: The directory for the IR code. +// +// The following files are in that directory: +// - "ir_dialect.td" +// - "ir_ops.generated.td" +// - "interfaces.td" +// - "conversion/ir_to_ast.h" +// - "conversion/ir_to_ast.generated.cc" +// +// This is used to print the #includes and header guards. 
+std::string PrintIrToAstSource(const AstDef &ast, + absl::string_view cc_namespace, + absl::string_view ast_path, + absl::string_view ir_path); + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_AST_GEN_H_ diff --git a/maldoca/astgen/ast_gen_main.cc b/maldoca/astgen/ast_gen_main.cc new file mode 100644 index 0000000..b030da6 --- /dev/null +++ b/maldoca/astgen/ast_gen_main.cc @@ -0,0 +1,125 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + bazel build //maldoca/astgen:ast_gen_main + + ./bazel-bin/maldoca/astgen/ast_gen_main \ + --ast_def_path="maldoca/js/ast/ast_def.textproto" \ + --cc_namespace="maldoca" \ + --ast_path="maldoca/js/ast" \ + --ir_path="maldoca/js/ir" + */ + +#include +#include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/str_cat.h" +#include "maldoca/astgen/ast_def.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/ast_gen.h" +#include "maldoca/base/filesystem.h" +#include "maldoca/base/path.h" +#include "maldoca/base/status_macros.h" + +ABSL_FLAG(std::string, ast_def_path, "", + "The path to the ast_def.textproto file."); + +ABSL_FLAG(std::string, cc_namespace, "", + "The C++ namespace for the AST classes in C++."); + +ABSL_FLAG(std::string, ast_path, "", "The directory for the AST code in C++."); + +ABSL_FLAG(std::string, ir_path, "", + "The directory for the IR code in TableGen and C++."); + +namespace maldoca { +namespace { + +absl::Status AstGenMain() { + auto 
ast_def_path = absl::GetFlag(FLAGS_ast_def_path); + auto cc_namespace = absl::GetFlag(FLAGS_cc_namespace); + auto ast_path = absl::GetFlag(FLAGS_ast_path); + auto ir_path = absl::GetFlag(FLAGS_ir_path); + + AstDefPb ast_def_pb; + MALDOCA_RETURN_IF_ERROR(ParseTextProtoFile(ast_def_path, &ast_def_pb)); + MALDOCA_ASSIGN_OR_RETURN(AstDef ast_def, AstDef::FromProto(ast_def_pb)); + + std::string ast_hdr = PrintAstHeader(ast_def, cc_namespace, ast_path); + auto ast_hdr_path = JoinPath(ast_path, "ast.generated.h"); + std::cout << "Writing ast_hdr to " << ast_hdr_path << "\n"; + MALDOCA_RETURN_IF_ERROR(SetFileContents(ast_hdr_path, ast_hdr)); + + std::string ast_src = PrintAstSource(ast_def, cc_namespace, ast_path); + auto ast_src_path = JoinPath(ast_path, "ast.generated.cc"); + std::cout << "Writing ast_src to " << ast_src_path << "\n"; + MALDOCA_RETURN_IF_ERROR(SetFileContents(ast_src_path, ast_src)); + + std::string ast_to_json = PrintAstToJson(ast_def, cc_namespace, ast_path); + auto ast_to_json_path = JoinPath(ast_path, "ast_to_json.generated.cc"); + std::cout << "Writing ast_to_json to " << ast_to_json_path << "\n"; + MALDOCA_RETURN_IF_ERROR( + SetFileContents(ast_to_json_path, ast_to_json)); + + std::string ast_from_json = PrintAstFromJson(ast_def, cc_namespace, ast_path); + auto ast_from_json_path = + JoinPath(ast_path, "ast_from_json.generated.cc"); + std::cout << "Writing ast_from_json to " << ast_from_json_path << "\n"; + MALDOCA_RETURN_IF_ERROR( + SetFileContents(ast_from_json_path, ast_from_json)); + + if (!ir_path.empty()) { + std::string ir_tablegen = PrintIrTableGen(ast_def, ir_path); + auto ir_tablegen_path = JoinPath( + ir_path, absl::StrCat(ast_def.lang_name(), "ir_ops.generated.td")); + std::cout << "Writing ir_tablegen to " << ir_tablegen_path << "\n"; + MALDOCA_RETURN_IF_ERROR( + SetFileContents(ir_tablegen_path, ir_tablegen)); + + std::string ast_to_ir = + PrintAstToIrSource(ast_def, cc_namespace, ast_path, ir_path); + auto ast_to_ir_path = JoinPath( 
+ ir_path, "conversion", + absl::StrCat("ast_to_", ast_def.lang_name(), "ir.generated.cc")); + std::cout << "Writing ast_to_ir to " << ast_to_ir_path << "\n"; + MALDOCA_RETURN_IF_ERROR( + SetFileContents(ast_to_ir_path, ast_to_ir)); + + std::string ir_to_ast = + PrintIrToAstSource(ast_def, cc_namespace, ast_path, ir_path); + auto ir_to_ast_path = JoinPath( + ir_path, "conversion", + absl::StrCat(ast_def.lang_name(), "ir_to_ast.generated.cc")); + std::cout << "Writing ir_to_ast to " << ir_to_ast_path << "\n"; + MALDOCA_RETURN_IF_ERROR( + SetFileContents(ir_to_ast_path, ir_to_ast)); + } + + return absl::OkStatus(); +} + +} // namespace +} // namespace maldoca + +int main(int argc, char* argv[]) { + + auto status = maldoca::AstGenMain(); + if (!status.ok()) { + std::cerr << "Error: " << status.ToString() << std::endl; + return 1; + } + + return 0; +} diff --git a/maldoca/astgen/ast_gen_test.cc b/maldoca/astgen/ast_gen_test.cc new file mode 100644 index 0000000..81dd1c5 --- /dev/null +++ b/maldoca/astgen/ast_gen_test.cc @@ -0,0 +1,522 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "maldoca/astgen/ast_gen.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "maldoca/astgen/ast_def.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/symbol.h" +#include "maldoca/astgen/type.h" +#include "maldoca/base/filesystem.h" +#include "maldoca/base/testing/status_matchers.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace maldoca { +namespace { + +struct PrintFieldDefTestCase { + const char *field_def; + const char *ts_interface_field; + const char *cc_member_variable; +}; + +void TestPrintFieldDef(const PrintFieldDefTestCase &test_case) { + FieldDefPb field_def_pb; + MALDOCA_ASSERT_OK(ParseTextProto(test_case.field_def, "test_case.field_def", + &field_def_pb)); + + MALDOCA_ASSERT_OK_AND_ASSIGN(auto field_def, + FieldDef::FromFieldDefPb(field_def_pb, "UsedLanguage")); + + std::string ts_interface_field; + { + google::protobuf::io::StringOutputStream os(&ts_interface_field); + TsInterfacePrinter printer(&os); + printer.PrintFieldDef(field_def); + } + EXPECT_EQ(ts_interface_field, test_case.ts_interface_field); + + std::string cc_member_variable; + { + google::protobuf::io::StringOutputStream os(&cc_member_variable); + AstHeaderPrinter printer(&os); + printer.PrintMemberVariable(field_def, "UsedLanguage"); + } + EXPECT_EQ(cc_member_variable, test_case.cc_member_variable); +} + +TEST(PrintFieldDef, BuiltinType) { + TestPrintFieldDef(PrintFieldDefTestCase{ + .field_def = R"pb( + name: "field" + type { bool {} } + )pb", + .ts_interface_field = "field: boolean\n", + .cc_member_variable = "bool field_;\n", + }); +} + +TEST(PrintFieldDef, TypeMaybeNull) { + TestPrintFieldDef(PrintFieldDefTestCase{ + .field_def = R"pb( + name: "field" + type { bool {} } + optionalness: OPTIONALNESS_MAYBE_NULL + )pb", + .ts_interface_field = "field: boolean | null\n", + .cc_member_variable = "std::optional field_;\n", + }); +} + +TEST(PrintFieldDef, 
TypeMaybeUndefined) { + TestPrintFieldDef(PrintFieldDefTestCase{ + .field_def = R"pb( + name: "field" + type { bool {} } + optionalness: OPTIONALNESS_MAYBE_UNDEFINED + )pb", + .ts_interface_field = "field?: boolean\n", + .cc_member_variable = "std::optional field_;\n", + }); +} + +TEST(PrintFieldDef, PrintMultiLineTitle) { + static const char kExpectedOutput[] = R"( +// ============================================================================= +// Title Line 1 +// Title Line 2 +// Title Line 3 +// ============================================================================= + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + CcPrinterBase printer(&os); + printer.PrintTitle("Title Line 1\nTitle Line 2\nTitle Line 3"); + } + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(PrintFieldDef, PrintMultiLineTitleWithEmptyLine) { + static const char kExpectedOutput[] = R"( +// ============================================================================= +// Title Line 1 +// +// Title Line 3 +// ============================================================================= + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + CcPrinterBase printer(&os); + printer.PrintTitle("Title Line 1\n\nTitle Line 3"); + } + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +// ============================================================================= +// AstFromJsonPrinter::PrintBuiltinFromJson() +// ============================================================================= + +TEST(AstFromJsonPrinterTest, TestPrintAssignBuiltinFromJson) { + static const char kExpectedOutput[] = R"( +my_lhs = my_rhs.get(); + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintBuiltinFromJson( + AstFromJsonPrinter::Action::kAssign, + 
AstFromJsonPrinter::CheckJsonType::kNo, + BuiltinType(BuiltinTypeKind::kString, "UsedLanguage"), Symbol("my_lhs"), + Symbol("my_rhs")); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintDefBuiltinFromJson) { + static const char kExpectedOutput[] = R"( +auto my_lhs = my_rhs.get(); + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintBuiltinFromJson( + AstFromJsonPrinter::Action::kDef, + AstFromJsonPrinter::CheckJsonType::kNo, + BuiltinType(BuiltinTypeKind::kString, "UsedLanguage"), Symbol("my_lhs"), + Symbol("my_rhs")); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintReturnBuiltinFromJson) { + static const char kExpectedOutput[] = R"( +return my_rhs.get(); + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintBuiltinFromJson( + AstFromJsonPrinter::Action::kReturn, + AstFromJsonPrinter::CheckJsonType::kNo, + BuiltinType(BuiltinTypeKind::kString, "UsedLanguage"), Symbol("my_lhs"), + Symbol("my_rhs")); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintAssignClassFromJson) { + static const char kExpectedOutput[] = R"( +MALDOCA_ASSIGN_OR_RETURN(my_lhs, UsedLanguageClassType::FromJson(my_rhs)); + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintClassFromJson(AstFromJsonPrinter::Action::kAssign, + ClassType(Symbol("ClassType"), "UsedLanguage"), + Symbol("my_lhs"), Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + 
+TEST(AstFromJsonPrinterTest, TestPrintDefClassFromJson) { + static const char kExpectedOutput[] = R"( +MALDOCA_ASSIGN_OR_RETURN(auto my_lhs, UsedLanguageClassType::FromJson(my_rhs)); + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintClassFromJson(AstFromJsonPrinter::Action::kDef, + ClassType(Symbol("ClassType"), "UsedLanguage"), + Symbol("my_lhs"), Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintReturnClassFromJson) { + static const char kExpectedOutput[] = R"( +return UsedLanguageClassType::FromJson(my_rhs); + )"; + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintClassFromJson(AstFromJsonPrinter::Action::kReturn, + ClassType(Symbol("ClassType"), "UsedLanguage"), + Symbol("my_lhs"), Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +// ============================================================================= +// AstFromJsonPrinter::PrintVariantFromJson() +// ============================================================================= + +TEST(AstFromJsonPrinterTest, TestPrintAssignVariantFromJson) { + static const char kExpectedOutput[] = R"( +if (my_rhs.is_string()) { + my_lhs = my_rhs.get(); +} else if (IsClassType(my_rhs)) { + MALDOCA_ASSIGN_OR_RETURN(my_lhs, UsedLanguageClassType::FromJson(my_rhs)); +} else { + auto result = absl::InvalidArgumentError("my_rhs has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{my_rhs.dump()}); + return result; +} + )"; + + std::vector> types; + types.push_back( + absl::make_unique(BuiltinTypeKind::kString, "UsedLanguage")); + types.push_back( + 
absl::make_unique(Symbol("ClassType"), "UsedLanguage")); + auto variant_type = VariantType(std::move(types), "UsedLanguage"); + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintVariantFromJson(AstFromJsonPrinter::Action::kAssign, + variant_type, Symbol("my_lhs"), + Symbol("my_rhs"), "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintDefVariantFromJson) { + static const char kExpectedOutput[] = R"( +std::variant> my_lhs; +if (my_rhs.is_string()) { + my_lhs = my_rhs.get(); +} else if (IsClassType(my_rhs)) { + MALDOCA_ASSIGN_OR_RETURN(my_lhs, UsedLanguageClassType::FromJson(my_rhs)); +} else { + auto result = absl::InvalidArgumentError("my_rhs has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{my_rhs.dump()}); + return result; +} + )"; + + std::vector> types; + types.push_back( + absl::make_unique(BuiltinTypeKind::kString, "UsedLanguage")); + types.push_back( + absl::make_unique(Symbol("ClassType"), "UsedLanguage")); + auto variant_type = VariantType(std::move(types), "UsedLanguage"); + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintVariantFromJson(AstFromJsonPrinter::Action::kDef, variant_type, + Symbol("my_lhs"), Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintReturnVariantFromJson) { + static const char kExpectedOutput[] = R"( +if (my_rhs.is_string()) { + return my_rhs.get(); +} else if (IsClassType(my_rhs)) { + return UsedLanguageClassType::FromJson(my_rhs); +} else { + auto result = absl::InvalidArgumentError("my_rhs has invalid type."); + result.SetPayload("json", 
absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{my_rhs.dump()}); + return result; +} + )"; + + std::vector> types; + types.push_back( + absl::make_unique(BuiltinTypeKind::kString, "UsedLanguage")); + types.push_back( + absl::make_unique(Symbol("ClassType"), "UsedLanguage")); + auto variant_type = VariantType(std::move(types), "UsedLanguage"); + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintVariantFromJson(AstFromJsonPrinter::Action::kReturn, + variant_type, Symbol("my_lhs"), + Symbol("my_rhs"), "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +// ============================================================================= +// AstFromJsonPrinter::PrintListFromJson() +// ============================================================================= + +TEST(AstFromJsonPrinterTest, TestPrintAssignListFromJson) { + static const char kExpectedOutput[] = R"( +if (!my_rhs.is_array()) { + return absl::InvalidArgumentError("my_rhs expected to be array."); +} + +std::vector my_lhs_value; +for (const nlohmann::json& my_rhs_element : my_rhs) { + if (my_rhs_element.is_null()) { + return absl::InvalidArgumentError("my_rhs_element is null."); + } + if (!my_rhs_element.is_string()) { + return absl::InvalidArgumentError("Expecting my_rhs_element.is_string()."); + } + auto my_lhs_element = my_rhs_element.get(); + my_lhs_value.push_back(std::move(my_lhs_element)); +} +my_lhs = std::move(my_lhs_value); + )"; + + auto element_type = + absl::make_unique(BuiltinTypeKind::kString, "UsedLanguage"); + auto list_type = + ListType(std::move(element_type), MaybeNull::kNo, "UsedLanguage"); + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintListFromJson(AstFromJsonPrinter::Action::kAssign, list_type, + Symbol("my_lhs"), 
Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintDefListFromJson) { + static const char kExpectedOutput[] = R"( +if (!my_rhs.is_array()) { + return absl::InvalidArgumentError("my_rhs expected to be array."); +} + +std::vector my_lhs; +for (const nlohmann::json& my_rhs_element : my_rhs) { + if (my_rhs_element.is_null()) { + return absl::InvalidArgumentError("my_rhs_element is null."); + } + if (!my_rhs_element.is_string()) { + return absl::InvalidArgumentError("Expecting my_rhs_element.is_string()."); + } + auto my_lhs_element = my_rhs_element.get(); + my_lhs.push_back(std::move(my_lhs_element)); +} + )"; + + auto element_type = + absl::make_unique(BuiltinTypeKind::kString, "UsedLanguage"); + auto list_type = + ListType(std::move(element_type), MaybeNull::kNo, "UsedLanguage"); + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintListFromJson(AstFromJsonPrinter::Action::kDef, list_type, + Symbol("my_lhs"), Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintReturnListFromJson) { + static const char kExpectedOutput[] = R"( +if (!my_rhs.is_array()) { + return absl::InvalidArgumentError("my_rhs expected to be array."); +} + +std::vector my_lhs; +for (const nlohmann::json& my_rhs_element : my_rhs) { + if (my_rhs_element.is_null()) { + return absl::InvalidArgumentError("my_rhs_element is null."); + } + if (!my_rhs_element.is_string()) { + return absl::InvalidArgumentError("Expecting my_rhs_element.is_string()."); + } + auto my_lhs_element = my_rhs_element.get(); + my_lhs.push_back(std::move(my_lhs_element)); +} +return my_lhs; + )"; + + auto element_type = + absl::make_unique(BuiltinTypeKind::kString, "UsedLanguage"); + auto 
list_type = + ListType(std::move(element_type), MaybeNull::kNo, "UsedLanguage"); + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintListFromJson(AstFromJsonPrinter::Action::kReturn, list_type, + Symbol("my_lhs"), Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +TEST(AstFromJsonPrinterTest, TestPrintDefListOfOptionalElementsFromJson) { + static const char kExpectedOutput[] = R"( +if (!my_rhs.is_array()) { + return absl::InvalidArgumentError("my_rhs expected to be array."); +} + +std::vector> my_lhs; +for (const nlohmann::json& my_rhs_element : my_rhs) { + std::optional my_lhs_element; + if (!my_rhs_element.is_null()) { + if (!my_rhs_element.is_string()) { + return absl::InvalidArgumentError("Expecting my_rhs_element.is_string()."); + } + my_lhs_element = my_rhs_element.get(); + } + my_lhs.push_back(std::move(my_lhs_element)); +} + )"; + + auto element_type = + absl::make_unique(BuiltinTypeKind::kString, "UsedLanguage"); + auto list_type = + ListType(std::move(element_type), MaybeNull::kYes, "UsedLanguage"); + + std::string output; + { + google::protobuf::io::StringOutputStream os(&output); + AstFromJsonPrinter printer(&os); + printer.PrintListFromJson(AstFromJsonPrinter::Action::kDef, list_type, + Symbol("my_lhs"), Symbol("my_rhs"), + "UsedLanguage"); + } + + EXPECT_EQ(absl::StripAsciiWhitespace(output), + absl::StripAsciiWhitespace(kExpectedOutput)); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/symbol.cc b/maldoca/astgen/symbol.cc new file mode 100644 index 0000000..221ad09 --- /dev/null +++ b/maldoca/astgen/symbol.cc @@ -0,0 +1,229 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/symbol.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" + +namespace maldoca { + +Symbol::Symbol(absl::string_view str) { + // Transforms the string to lower case words. + bool should_create_new_word = true; + for (char ch : str) { + if (absl::ascii_isupper(ch)) { + words_.push_back(""); + words_.back().push_back(absl::ascii_tolower(ch)); + should_create_new_word = false; + + } else if (ch == '_') { + should_create_new_word = true; + + } else { + if (should_create_new_word) { + words_.push_back(""); + } + words_.back().push_back(ch); + should_create_new_word = false; + } + } + + // An unfortunate patchwork: If the input ends with '_', this '_' is included + // in the last word. + // + // Example: Let's say we want to define a field with the name "operator", we + // have to define the MLIR field in the TableGen file as "operator_". Then, + // the MLIR getter would become "getOperator_()" or "getOperator_Attr()". + // + // You might ask: given that the MLIR getter is already prefixed with "get", + // why do we still need the "_"? Well, MLIR still generates the builder + // argument name unchanged. + // + // Now, we need a way to turn "operator" into "getOperator_Attr". 
+ // This is done by: + // (Symbol("get") + Symbol("operator").ToCcVarName() + "attr").ToCamelCase() + if (!words_.empty() && str.back() == '_') { // all-'_' input has no words + words_.back().push_back('_'); + } +} + +Symbol &Symbol::operator+=(const Symbol &other) { + words_.insert(words_.end(), other.words_.begin(), other.words_.end()); + return *this; +} + +Symbol &Symbol::operator+=(Symbol &&other) { + words_.insert(words_.end(), std::make_move_iterator(other.words_.begin()), + std::make_move_iterator(other.words_.end())); + return *this; +} + +Symbol &Symbol::operator+=(absl::string_view other) { + return operator+=(Symbol(other)); +} + +std::string Symbol::ToSnakeCase() const { return absl::StrJoin(words_, "_"); } + +bool Symbol::IsReservedKeyword() const { + // https://en.cppreference.com/w/cpp/keyword + static const auto *kReservedKeywords = new absl::flat_hash_set{ + "alignas", + "alignof", + "and", + "and_eq", + "asm", + "atomic_cancel", + "atomic_commit", + "atomic_noexcept", + "auto", + "bitand", + "bitor", + "bool", + "break", + "case", + "catch", + "char", + "char8_t", + "char16_t", + "char32_t", + "class", + "compl", + "concept", + "const", + "consteval", + "constexpr", + "constinit", + "const_cast", + "continue", + "co_await", + "co_return", + "co_yield", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "float", + "for", + "friend", + "goto", + "if", + "inline", + "int", + "long", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "private", + "protected", + "public", + "reflexpr", + "register", + "reinterpret_cast", + "requires", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "synchronized", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", +
"virtual", + "void", + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq", + + // Since https://reviews.llvm.org/D141742, "properties" cannot be an + // argument name in an MLIR op. + "properties", + }; + + std::string snake_case = ToSnakeCase(); + return kReservedKeywords->contains(snake_case); +} + +std::string Symbol::ToCcVarName() const { + std::string result = ToSnakeCase(); + if (IsReservedKeyword()) { + result.push_back('_'); + } + return result; +} + +std::string Symbol::ToMlirGetter() const { + std::string result = absl::StrCat("get", ToPascalCase()); + if (IsReservedKeyword()) { + result.push_back('_'); + } + return result; +} + +std::string Symbol::ToPascalCase() const { + std::string result; + for (const auto &word : words_) { + result.push_back(absl::ascii_toupper(word[0])); + result.append(word.substr(1)); + } + return result; +} + +std::string Symbol::ToCamelCase() const { + std::string result; + for (const auto &word : words_) { + if (result.empty()) { + result = word; + } else { + result.push_back(absl::ascii_toupper(word[0])); + result.append(word.substr(1)); + } + } + return result; +} + +} // namespace maldoca diff --git a/maldoca/astgen/symbol.h b/maldoca/astgen/symbol.h new file mode 100644 index 0000000..fe02a8d --- /dev/null +++ b/maldoca/astgen/symbol.h @@ -0,0 +1,87 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_SYMBOL_H_ +#define MALDOCA_ASTGEN_SYMBOL_H_ + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" + +namespace maldoca { + +// Models a list of words, and supports printing in snake_case, PascalCase, and +// camelCase. +// +// Also supports concatenation. +// +// For example, if a field is named "sourceType", then: +// - C++ variable name: source_type (printing snake_case) +// - Protobuf field name: source_type (printing snake_case) +// - C++ setter function name: set_source_type (concatenation) +// - JavaScript field name: sourceType (printing camelCase) +// - JSPB getter/setter: {get,set}SourceType (concatenation, printing camelCase) +class Symbol { + public: + // Input can be either snake_case, PascalCase, or camelCase. + explicit Symbol(absl::string_view str = ""); + + // Concatenation. + // E.g. "one_two" + "three_four" => "one_two_three_four" + Symbol &operator+=(const Symbol &other); + Symbol &operator+=(Symbol &&other); + + Symbol &operator+=(absl::string_view other); + + template + Symbol operator+(T &&other) const { + Symbol words = *this; + words += other; + return words; + } + + // "snake_case" + std::string ToSnakeCase() const; + + // Same as snake_case, but adds a '_' if it collides with a reserved keyword. + // + // E.g. Symbol("static").ToCcVarName() => "static_" + std::string ToCcVarName() const; + + // "PascalCase" + std::string ToPascalCase() const; + + // "getPascalCase", but adds a '_' if the field name collides with a reserved + // keyword. + // + // E.g. Symbol("static").ToMlirGetter() => "getStatic_" + std::string ToMlirGetter() const; + + // "camelCase" + std::string ToCamelCase() const; + + private: + bool IsReservedKeyword() const; + + // Always store in lower case.
+ std::vector words_; +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_SYMBOL_H_ diff --git a/maldoca/astgen/symbol_test.cc b/maldoca/astgen/symbol_test.cc new file mode 100644 index 0000000..7bd6c53 --- /dev/null +++ b/maldoca/astgen/symbol_test.cc @@ -0,0 +1,89 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/symbol.h" + +#include "gtest/gtest.h" + +namespace maldoca { +namespace { + +TEST(SymbolTest, FromPascalCase) { + Symbol symbol("GetLeftHandSide"); + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side"); +} + +TEST(SymbolTest, FromCamelCase) { + Symbol symbol("getLeftHandSide"); + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side"); +} + +TEST(SymbolTest, FromSnakeCase) { + Symbol symbol("get_left_hand_side"); + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side"); +} + +TEST(SymbolTest, ExtraUnderscoresAreIgnored) { + Symbol symbol("_get_left_hand_side"); + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side"); + + symbol = 
Symbol("get_left_hand_side_"); + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide_"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide_"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side_"); + + symbol = Symbol("get__left_hand_side"); + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side"); +} + +TEST(SymbolTest, ConcatenateSymbols) { + Symbol first("get_left"); + Symbol second("HandSide"); + Symbol symbol = first + second; + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side"); +} + +TEST(SymbolTest, ConcatenateSymbolWithString) { + Symbol first("get_left"); + auto second = "HandSide"; + Symbol symbol = first + second; + EXPECT_EQ(symbol.ToPascalCase(), "GetLeftHandSide"); + EXPECT_EQ(symbol.ToCamelCase(), "getLeftHandSide"); + EXPECT_EQ(symbol.ToSnakeCase(), "get_left_hand_side"); +} + +TEST(SymbolTest, AvoidCppKeyword) { + Symbol symbol("operator"); + EXPECT_EQ(symbol.ToCcVarName(), "operator_"); + EXPECT_EQ(Symbol(symbol.ToCcVarName()).ToPascalCase(), "Operator_"); + EXPECT_EQ((Symbol("get") + symbol.ToCcVarName() + "attr").ToCamelCase(), + "getOperator_Attr"); + EXPECT_EQ((Symbol("get") + symbol.ToCcVarName() + "attr").ToSnakeCase(), + "get_operator__attr"); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/BUILD b/maldoca/astgen/test/BUILD new file mode 100644 index 0000000..76f8d28 --- /dev/null +++ b/maldoca/astgen/test/BUILD @@ -0,0 +1,61 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//:__subpackages__", + ], +) + +cc_library( + name = "ast_gen_test_util", + testonly = True, + srcs = ["ast_gen_test_util.cc"], + hdrs = ["ast_gen_test_util.h"], + deps = [ + "//maldoca/astgen:ast_def", + "//maldoca/astgen:ast_def_cc_proto", + "//maldoca/astgen:ast_gen", + "//maldoca/base:filesystem", + "//maldoca/base:get_runfiles_dir", + "//maldoca/base:status", + "//maldoca/base/testing:status_matchers", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:optional", + "@googletest//:gtest", + ], +) + +cc_library( + name = "conversion_test_util", + testonly = True, + hdrs = ["conversion_test_util.h"], + deps = [ + "//maldoca/base/testing:status_matchers", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@nlohmann_json//:json", + ], +) diff --git a/maldoca/astgen/test/assign/BUILD b/maldoca/astgen/test/assign/BUILD new file mode 100644 index 0000000..48d7459 --- /dev/null +++ b/maldoca/astgen/test/assign/BUILD @@ -0,0 +1,184 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//maldoca/astgen:__subpackages__", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "air_ops.generated.td", + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + "//maldoca/astgen/test/assign/conversion:air_to_ast.generated.cc", + "//maldoca/astgen/test/assign/conversion:ast_to_air.generated.cc", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = ["ast.generated.h"], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) + +td_library( + name = "interfaces_td_files", + srcs = [ + "interfaces.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "interfaces_inc_gen", + tbl_outs = { + "interfaces.h.inc": ["-gen-op-interface-decls"], + 
"interfaces.cc.inc": ["-gen-op-interface-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "interfaces.td", + deps = [":interfaces_td_files"], +) + +td_library( + name = "air_dialect_td_files", + srcs = [ + "air_dialect.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "air_dialect_inc_gen", + tbl_outs = { + "air_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=air", + ], + "air_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=air", + ], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "air_dialect.td", + deps = [":air_dialect_td_files"], +) + +td_library( + name = "air_types_td_files", + srcs = [ + "air_types.td", + ], + deps = [ + ":air_dialect_td_files", + ], +) + +gentbl_cc_library( + name = "air_types_inc_gen", + tbl_outs = { + "air_types.h.inc": ["-gen-typedef-decls"], + "air_types.cc.inc": ["-gen-typedef-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "air_types.td", + deps = [":air_types_td_files"], +) + +td_library( + name = "air_ops_generated_td_files", + srcs = [ + "air_ops.generated.td", + ], + deps = [ + ":air_dialect_td_files", + ":air_types_td_files", + ":interfaces_td_files", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "air_ops_generated_inc_gen", + tbl_outs = { + "air_ops.generated.h.inc": ["-gen-op-decls"], + "air_ops.generated.cc.inc": ["-gen-op-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "air_ops.generated.td", + deps = [":air_ops_generated_td_files"], +) + +cc_library( + name = "ir", + srcs = ["ir.cc"], + hdrs = ["ir.h"], + deps = [ + ":air_dialect_inc_gen", + ":air_ops_generated_inc_gen", + ":air_types_inc_gen", + ":interfaces_inc_gen", + 
"@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + ], +) diff --git a/maldoca/astgen/test/assign/air_dialect.td b/maldoca/astgen/test/assign/air_dialect.td new file mode 100644 index 0000000..a078759 --- /dev/null +++ b/maldoca/astgen/test/assign/air_dialect.td @@ -0,0 +1,45 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_ASSIGN_AIR_DIALECT_TD_ +#define MALDOCA_ASTGEN_TEST_ASSIGN_AIR_DIALECT_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def Air_Dialect : Dialect { + let name = "air"; + let cppNamespace = "::maldoca"; + + let description = [{ + The AssignIR, a test IR that models assignments. All ops and fields are + directly mapped from the AST. + + In AIR, we represent Identifiers with two ops: AirIdentifierOp for rvalue + and AirIdentifierRefOp for lvalue. 
+ }]; + + let useDefaultTypePrinterParser = 1; +} + +class Air_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = ?; +} + +class Air_Op traits = []> : + Op; + +#endif // MALDOCA_ASTGEN_TEST_ASSIGN_AIR_DIALECT_TD_ diff --git a/maldoca/astgen/test/assign/air_ops.generated.td b/maldoca/astgen/test/assign/air_ops.generated.td new file mode 100644 index 0000000..65b505d --- /dev/null +++ b/maldoca/astgen/test/assign/air_ops.generated.td @@ -0,0 +1,69 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_ASSIGN_AIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_ASSIGN_AIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/assign/interfaces.td" +include "maldoca/astgen/test/assign/air_dialect.td" +include "maldoca/astgen/test/assign/air_types.td" + +def AirIdentifierOp : Air_Op< + "identifier", [ + AirExpressionOpInterfaceTraits + ]> { + let arguments = (ins + StrAttr: $name + ); + + let results = (outs + AirAnyType + ); +} + +def AirIdentifierRefOp : Air_Op<"identifier_ref", []> { + let arguments = (ins + StrAttr: $name + ); + + let results = (outs + AirAnyType + ); +} + +def AirAssignmentOp : Air_Op< + "assignment", [ + AirExpressionOpInterfaceTraits + ]> { + let arguments = (ins + AnyType: $lhs, + AnyType: $rhs + ); + + let results = (outs + AirAnyType + ); +} + +#endif // MALDOCA_ASTGEN_TEST_ASSIGN_AIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/assign/air_types.td b/maldoca/astgen/test/assign/air_types.td new file mode 100644 index 0000000..7dea550 --- /dev/null +++ b/maldoca/astgen/test/assign/air_types.td @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_ASSIGN_AIR_TYPES_TD_ +#define MALDOCA_ASTGEN_TEST_ASSIGN_AIR_TYPES_TD_ + +include "maldoca/astgen/test/assign/air_dialect.td" + +def AirAnyType : Air_Type<"AirAny"> { + let summary = "A placeholder singleton type."; + let mnemonic = "any"; + let assemblyFormat = ""; +} + +#endif // MALDOCA_ASTGEN_TEST_ASSIGN_AIR_TYPES_TD_ diff --git a/maldoca/astgen/test/assign/ast.generated.cc b/maldoca/astgen/test/assign/ast.generated.cc new file mode 100644 index 0000000..81163d9 --- /dev/null +++ b/maldoca/astgen/test/assign/ast.generated.cc @@ -0,0 +1,127 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#include "maldoca/astgen/test/assign/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// AExpression +// ============================================================================= + +absl::string_view AExpressionTypeToString(AExpressionType expression_type) { + switch (expression_type) { + case AExpressionType::kIdentifier: + return "Identifier"; + case AExpressionType::kAssignment: + return "Assignment"; + } +} + +absl::StatusOr StringToAExpressionType(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"Identifier", AExpressionType::kIdentifier}, + {"Assignment", AExpressionType::kAssignment}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for AExpressionType: ", s)); + } + return it->second; +} + +// ============================================================================= +// AIdentifier +// ============================================================================= + +AIdentifier::AIdentifier( + std::string name) + : AExpression(), + name_(std::move(name)) {} + +absl::string_view AIdentifier::name() const { + return name_; +} + +void AIdentifier::set_name(std::string name) { + name_ = std::move(name); +} + +// ============================================================================= +// AAssignment +// 
============================================================================= + +AAssignment::AAssignment( + std::unique_ptr lhs, + std::unique_ptr rhs) + : AExpression(), + lhs_(std::move(lhs)), + rhs_(std::move(rhs)) {} + +AIdentifier* AAssignment::lhs() { + return lhs_.get(); +} + +const AIdentifier* AAssignment::lhs() const { + return lhs_.get(); +} + +void AAssignment::set_lhs(std::unique_ptr lhs) { + lhs_ = std::move(lhs); +} + +AExpression* AAssignment::rhs() { + return rhs_.get(); +} + +const AExpression* AAssignment::rhs() const { + return rhs_.get(); +} + +void AAssignment::set_rhs(std::unique_ptr rhs) { + rhs_ = std::move(rhs); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/assign/ast.generated.h b/maldoca/astgen/test/assign/ast.generated.h new file mode 100644 index 0000000..0c87bb5 --- /dev/null +++ b/maldoca/astgen/test/assign/ast.generated.h @@ -0,0 +1,136 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_ASSIGN_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_ASSIGN_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +enum class AExpressionType { + kIdentifier, + kAssignment, +}; + +absl::string_view AExpressionTypeToString(AExpressionType expression_type); +absl::StatusOr StringToAExpressionType(absl::string_view s); + +class AExpression { + public: + virtual ~AExpression() = default; + + virtual AExpressionType expression_type() const = 0; + + virtual void Serialize(std::ostream& os) const = 0; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class AIdentifier : public virtual AExpression { + public: + explicit AIdentifier( + std::string name); + + AExpressionType expression_type() const override { + return AExpressionType::kIdentifier; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + absl::string_view name() const; + void set_name(std::string name); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr GetName(const nlohmann::json& json); + + private: + std::string name_; +}; + +class AAssignment : public virtual AExpression { + public: + explicit AAssignment( + std::unique_ptr lhs, + std::unique_ptr rhs); + + AExpressionType expression_type() const override { + return AExpressionType::kAssignment; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + AIdentifier* lhs(); + const AIdentifier* lhs() const; + void set_lhs(std::unique_ptr lhs); + + AExpression* rhs(); + const AExpression* rhs() const; + void set_rhs(std::unique_ptr rhs); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetLhs(const nlohmann::json& json); + static absl::StatusOr> GetRhs(const nlohmann::json& json); + + private: + std::unique_ptr lhs_; + std::unique_ptr rhs_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_ASSIGN_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/assign/ast_def.textproto b/maldoca/astgen/test/assign/ast_def.textproto new file mode 100644 index 0000000..123a5e4 --- /dev/null +++ b/maldoca/astgen/test/assign/ast_def.textproto @@ -0,0 +1,47 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "a" + +# interface Expression {} +nodes { + name: "Expression" + kinds: FIELD_KIND_RVAL +} + +# interface Identifier <: Expression { +# name: string +# } +nodes { + name: "Identifier" + type: "Identifier" + parents: "Expression" + fields { + name: "name" + type { string {} } + kind: FIELD_KIND_ATTR + } + kinds: FIELD_KIND_LVAL + should_generate_ir_op: true +} + +# interface 
Assignment <: Expression { +# lhs: Identifier +# rhs: Expression +# } +nodes { + name: "Assignment" + type: "Assignment" + parents: "Expression" + fields { + name: "lhs" + type { class: "Identifier" } + kind: FIELD_KIND_LVAL + } + fields { + name: "rhs" + type { class: "Expression" } + kind: FIELD_KIND_RVAL + } + should_generate_ir_op: true +} diff --git a/maldoca/astgen/test/assign/ast_from_json.generated.cc b/maldoca/astgen/test/assign/ast_from_json.generated.cc new file mode 100644 index 0000000..1516eda --- /dev/null +++ b/maldoca/astgen/test/assign/ast_from_json.generated.cc @@ -0,0 +1,161 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/assign/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +static absl::StatusOr GetType(const nlohmann::json& json) { + auto type_it = json.find("type"); + if (type_it == json.end()) { + return absl::InvalidArgumentError("`type` is undefined."); + } + const nlohmann::json& json_type = type_it.value(); + if (json_type.is_null()) { + return absl::InvalidArgumentError("json_type is null."); + } + if (!json_type.is_string()) { + return absl::InvalidArgumentError("`json_type` expected to be string."); + } + return json_type.get(); +} + +// ============================================================================= +// AExpression +// ============================================================================= + +absl::StatusOr> +AExpression::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "Identifier") { + return AIdentifier::FromJson(json); + } else if (type == "Assignment") { + return AAssignment::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// AIdentifier +// ============================================================================= + +absl::StatusOr +AIdentifier::GetName(const nlohmann::json& json) { + auto name_it 
= json.find("name"); + if (name_it == json.end()) { + return absl::InvalidArgumentError("`name` is undefined."); + } + const nlohmann::json& json_name = name_it.value(); + + if (json_name.is_null()) { + return absl::InvalidArgumentError("json_name is null."); + } + if (!json_name.is_string()) { + return absl::InvalidArgumentError("Expecting json_name.is_string()."); + } + return json_name.get(); +} + +absl::StatusOr> +AIdentifier::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto name, AIdentifier::GetName(json)); + + return absl::make_unique( + std::move(name)); +} + +// ============================================================================= +// AAssignment +// ============================================================================= + +absl::StatusOr> +AAssignment::GetLhs(const nlohmann::json& json) { + auto lhs_it = json.find("lhs"); + if (lhs_it == json.end()) { + return absl::InvalidArgumentError("`lhs` is undefined."); + } + const nlohmann::json& json_lhs = lhs_it.value(); + + if (json_lhs.is_null()) { + return absl::InvalidArgumentError("json_lhs is null."); + } + return AIdentifier::FromJson(json_lhs); +} + +absl::StatusOr> +AAssignment::GetRhs(const nlohmann::json& json) { + auto rhs_it = json.find("rhs"); + if (rhs_it == json.end()) { + return absl::InvalidArgumentError("`rhs` is undefined."); + } + const nlohmann::json& json_rhs = rhs_it.value(); + + if (json_rhs.is_null()) { + return absl::InvalidArgumentError("json_rhs is null."); + } + return AExpression::FromJson(json_rhs); +} + +absl::StatusOr> +AAssignment::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto lhs, AAssignment::GetLhs(json)); + MALDOCA_ASSIGN_OR_RETURN(auto rhs, AAssignment::GetRhs(json)); + + return absl::make_unique( + std::move(lhs), + 
std::move(rhs)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/assign/ast_gen_test.cc b/maldoca/astgen/test/assign/ast_gen_test.cc new file mode 100644 index 0000000..568c194 --- /dev/null +++ b/maldoca/astgen/test/assign/ast_gen_test.cc @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + Assign, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = + "maldoca/astgen/test/assign/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "assign/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/assign", + .ir_path = "maldoca/astgen/test/assign", + .expected_ast_header_path = + "maldoca/astgen/test/assign/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/assign/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "assign/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "assign/ast_from_json.generated.cc", + .expected_ir_tablegen_path = + "maldoca/astgen/test/" + "assign/air_ops.generated.td", + .expected_ast_to_ir_source_path = + "maldoca/astgen/test/assign/conversion/" + "ast_to_air.generated.cc", + 
.expected_ir_to_ast_source_path = + "maldoca/astgen/test/assign/conversion/" + "air_to_ast.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/assign/ast_to_json.generated.cc b/maldoca/astgen/test/assign/ast_to_json.generated.cc new file mode 100644 index 0000000..fe0cf91 --- /dev/null +++ b/maldoca/astgen/test/assign/ast_to_json.generated.cc @@ -0,0 +1,103 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/assign/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// AExpression +// ============================================================================= + +void AExpression::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +// ============================================================================= +// AIdentifier +// ============================================================================= + +void AIdentifier::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"name\":" << (nlohmann::json(name_)).dump(); +} + +void AIdentifier::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"Identifier\""; + AExpression::SerializeFields(os, needs_comma); + AIdentifier::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// AAssignment +// ============================================================================= + +void AAssignment::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"lhs\":"; + lhs_->Serialize(os); + MaybeAddComma(os, needs_comma); + os << "\"rhs\":"; + rhs_->Serialize(os); +} + +void 
AAssignment::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"Assignment\""; + AExpression::SerializeFields(os, needs_comma); + AAssignment::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/assign/ast_ts_interface.generated b/maldoca/astgen/test/assign/ast_ts_interface.generated new file mode 100644 index 0000000..27205b0 --- /dev/null +++ b/maldoca/astgen/test/assign/ast_ts_interface.generated @@ -0,0 +1,11 @@ +interface Expression { +} + +interface Identifier <: Expression { + name: string +} + +interface Assignment <: Expression { + lhs: Identifier + rhs: Expression +} diff --git a/maldoca/astgen/test/assign/conversion/BUILD b/maldoca/astgen/test/assign/conversion/BUILD new file mode 100644 index 0000000..8f43991 --- /dev/null +++ b/maldoca/astgen/test/assign/conversion/BUILD @@ -0,0 +1,76 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_applicable_licenses = ["//:license"]) + +licenses(["notice"]) + +exports_files([ + "ast_to_air.generated.cc", + "air_to_ast.generated.cc", +]) + +cc_library( + name = "ast_to_air", + srcs = ["ast_to_air.generated.cc"], + hdrs = ["ast_to_air.h"], + deps = [ + "//maldoca/astgen/test/assign:ast", + "//maldoca/astgen/test/assign:ir", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "air_to_ast", + srcs = ["air_to_ast.generated.cc"], + hdrs = ["air_to_ast.h"], + deps = [ + "//maldoca/astgen/test/assign:ast", + "//maldoca/astgen/test/assign:ir", + "//maldoca/base:status", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_test( + name = "conversion_test", + srcs = ["conversion_test.cc"], + deps = [ + ":air_to_ast", + ":ast_to_air", + "//maldoca/astgen/test:conversion_test_util", + "//maldoca/astgen/test/assign:ast", + "//maldoca/astgen/test/assign:ir", + "@googletest//:gtest_main", + ], +) diff --git a/maldoca/astgen/test/assign/conversion/air_to_ast.generated.cc b/maldoca/astgen/test/assign/conversion/air_to_ast.generated.cc new file mode 100644 index 0000000..d2cced5 --- /dev/null +++ b/maldoca/astgen/test/assign/conversion/air_to_ast.generated.cc @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/assign/conversion/air_to_ast.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/base/status_macros.h" +#include "maldoca/astgen/test/assign/ast.generated.h" +#include "maldoca/astgen/test/assign/ir.h" + +namespace maldoca { + +absl::StatusOr> +AirToAst::VisitExpression(AirExpressionOpInterface op) { + using Ret = absl::StatusOr>; + return llvm::TypeSwitch(op) + .Case([&](AirIdentifierOp op) { + return VisitIdentifier(op); + }) + .Case([&](AirAssignmentOp op) { + return VisitAssignment(op); + }) + .Default([&](mlir::Operation* op) { + return 
absl::InvalidArgumentError("Unrecognized op"); + }); +} + +absl::StatusOr> +AirToAst::VisitIdentifier(AirIdentifierOp op) { + std::string name = op.getNameAttr().str(); + return Create( + op, + std::move(name)); +} + +absl::StatusOr> +AirToAst::VisitIdentifierRef(AirIdentifierRefOp op) { + std::string name = op.getNameAttr().str(); + return Create( + op, + std::move(name)); +} + +absl::StatusOr> +AirToAst::VisitAssignment(AirAssignmentOp op) { + auto lhs_op = llvm::dyn_cast(op.getLhs().getDefiningOp()); + if (lhs_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected AirIdentifierRefOp, got ", + op.getLhs().getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr lhs, VisitIdentifierRef(lhs_op)); + auto rhs_op = llvm::dyn_cast(op.getRhs().getDefiningOp()); + if (rhs_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected AirExpressionOpInterface, got ", + op.getRhs().getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr rhs, VisitExpression(rhs_op)); + return Create( + op, + std::move(lhs), + std::move(rhs)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/assign/conversion/air_to_ast.h b/maldoca/astgen/test/assign/conversion/air_to_ast.h new file mode 100644 index 0000000..ae51d54 --- /dev/null +++ b/maldoca/astgen/test/assign/conversion/air_to_ast.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_ASSIGN_CONVERSION_AIR_TO_AST_H_ +#define MALDOCA_ASTGEN_TEST_ASSIGN_CONVERSION_AIR_TO_AST_H_ + +#include + +#include "mlir/IR/Operation.h" +#include "absl/status/statusor.h" +#include "maldoca/astgen/test/assign/ast.generated.h" +#include "maldoca/astgen/test/assign/ir.h" + +namespace maldoca { + +class AirToAst { + public: + absl::StatusOr> VisitExpression( + AirExpressionOpInterface op); + + absl::StatusOr> VisitIdentifier( + AirIdentifierOp op); + + absl::StatusOr> VisitIdentifierRef( + AirIdentifierRefOp op); + + absl::StatusOr> VisitAssignment( + AirAssignmentOp op); + + template + std::unique_ptr Create(mlir::Operation *op, Args &&...args) { + return absl::make_unique(std::forward(args)...); + } +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_ASSIGN_CONVERSION_AIR_TO_AST_H_ diff --git a/maldoca/astgen/test/assign/conversion/ast_to_air.generated.cc b/maldoca/astgen/test/assign/conversion/ast_to_air.generated.cc new file mode 100644 index 0000000..77e875b --- /dev/null +++ b/maldoca/astgen/test/assign/conversion/ast_to_air.generated.cc @@ -0,0 +1,78 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/assign/conversion/ast_to_air.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/astgen/test/assign/ast.generated.h" +#include "maldoca/astgen/test/assign/ir.h" + +namespace maldoca { + +AirExpressionOpInterface AstToAir::VisitExpression(const AExpression *node) { + if (auto *identifier = dynamic_cast(node)) { + return VisitIdentifier(identifier); + } + if (auto *assignment = dynamic_cast(node)) { + return VisitAssignment(assignment); + } + LOG(FATAL) << "Unreachable code."; +} + +AirIdentifierOp AstToAir::VisitIdentifier(const AIdentifier *node) { + mlir::StringAttr mlir_name = builder_.getStringAttr(node->name()); + return CreateExpr(node, mlir_name); +} + +AirIdentifierRefOp AstToAir::VisitIdentifierRef(const AIdentifier *node) { + mlir::StringAttr 
mlir_name = builder_.getStringAttr(node->name()); + return CreateExpr(node, mlir_name); +} + +AirAssignmentOp AstToAir::VisitAssignment(const AAssignment *node) { + mlir::Value mlir_lhs = VisitIdentifierRef(node->lhs()); + mlir::Value mlir_rhs = VisitExpression(node->rhs()); + return CreateExpr(node, mlir_lhs, mlir_rhs); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/assign/conversion/ast_to_air.h b/maldoca/astgen/test/assign/conversion/ast_to_air.h new file mode 100644 index 0000000..5d9b27d --- /dev/null +++ b/maldoca/astgen/test/assign/conversion/ast_to_air.h @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_ASSIGN_CONVERSION_AST_TO_AIR_H_ +#define MALDOCA_ASTGEN_TEST_ASSIGN_CONVERSION_AST_TO_AIR_H_ + +#include "mlir/IR/Builders.h" +#include "maldoca/astgen/test/assign/ast.generated.h" +#include "maldoca/astgen/test/assign/ir.h" + +namespace maldoca { + +class AstToAir { + public: + explicit AstToAir(mlir::OpBuilder &builder) : builder_(builder) {} + + AirIdentifierOp VisitIdentifier(const AIdentifier *node); + + AirIdentifierRefOp VisitIdentifierRef(const AIdentifier *node); + + AirAssignmentOp VisitAssignment(const AAssignment *node); + + AirExpressionOpInterface VisitExpression(const AExpression *node); + + private: + template + Op CreateExpr(const Node *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), + std::forward(args)...); + } + + mlir::OpBuilder &builder_; +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_ASSIGN_CONVERSION_AST_TO_AIR_H_ diff --git a/maldoca/astgen/test/assign/conversion/conversion_test.cc b/maldoca/astgen/test/assign/conversion/conversion_test.cc new file mode 100644 index 0000000..71d59bf --- /dev/null +++ b/maldoca/astgen/test/assign/conversion/conversion_test.cc @@ -0,0 +1,101 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/assign/ast.generated.h" +#include "maldoca/astgen/test/assign/conversion/air_to_ast.h" +#include "maldoca/astgen/test/assign/conversion/ast_to_air.h" +#include "maldoca/astgen/test/assign/ir.h" +#include "maldoca/astgen/test/conversion_test_util.h" + +namespace maldoca { +namespace { + +TEST(ConversionTest, SimpleAssignment) { + // a = b + constexpr char kAstJsonString[] = R"( + { + "type": "Assignment", + "lhs": { + "type": "Identifier", + "name": "a" + }, + "rhs": { + "type": "Identifier", + "name": "b" + } + } + )"; + + constexpr char kExpectedIrDump[] = R"( +module { + %0 = "air.identifier_ref"() <{name = "a"}> : () -> !air.any + %1 = "air.identifier"() <{name = "b"}> : () -> !air.any + %2 = "air.assignment"(%0, %1) : (!air.any, !air.any) -> !air.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToAir::VisitAssignment, + .ir_to_ast_visit = &AirToAst::VisitAssignment, + .expected_ir_dump = kExpectedIrDump, + }); +} + +TEST(ConversionTest, ChainAssignment) { + // a = (b = c) + constexpr char kAstJsonString[] = R"( + { + "type": "Assignment", + "lhs": { + "type": "Identifier", + "name": "a" + }, + "rhs": { + "type": "Assignment", + "lhs": { + "type": "Identifier", + "name": "b" + }, + "rhs": { + "type": "Identifier", + "name": "c" + } + } + } + )"; + + constexpr char kExpectedIrDump[] = R"( +module { + %0 = "air.identifier_ref"() <{name = "a"}> : () -> !air.any + %1 = "air.identifier_ref"() <{name = "b"}> : () -> !air.any + %2 = "air.identifier"() <{name = "c"}> : () -> !air.any + %3 = "air.assignment"(%1, %2) : (!air.any, !air.any) -> !air.any + %4 = "air.assignment"(%0, %3) : (!air.any, !air.any) -> !air.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToAir::VisitAssignment, + .ir_to_ast_visit = &AirToAst::VisitAssignment, + .expected_ir_dump = kExpectedIrDump, + }); +} + +} // namespace +} // namespace 
maldoca diff --git a/maldoca/astgen/test/assign/interfaces.td b/maldoca/astgen/test/assign/interfaces.td new file mode 100644 index 0000000..9740a13 --- /dev/null +++ b/maldoca/astgen/test/assign/interfaces.td @@ -0,0 +1,70 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definition of op interfaces used by the air dialect. +// +// Just like we model leaf classes as MLIR ops, we model non-leaf classes as +// MLIR interfaces. +// +// For example, `Identifier` inherits from `Expression`, so we define an +// interface `AirExpressionOpInterface`. +// +// This way, we can implicitly convert an `AirIdentifierOp` to an +// `AirExpressionOpInterface`: +// +// ``` +// AirIdentifierOp identifier = ...; +// AirExpressionOpInterface expression = identifier; +// ``` +// +// We can also type check and explicitly convert an `AirExpressionOpInterface` +// to an `AirIdentifierOp`: +// +// ``` +// AirExpressionOpInterface expression = ...; +// if (llvm::isa<AirIdentifierOp>(expression)) { +// auto identifier = llvm::cast<AirIdentifierOp>(expression); +// ... 
+// } +// ``` + +include "mlir/IR/OpBase.td" + +def AirExpressionOpInterface : OpInterface<"AirExpressionOpInterface"> { + let cppNamespace = "::maldoca"; + + let extraClassDeclaration = [{ + operator mlir::Value() { // NOLINT + return getOperation()->getResult(0); + } + }]; +} + +def AirExpressionOpInterfaceTraits : TraitList<[ + DeclareOpInterfaceMethods +]>; + +def AirExpressionRefOpInterface : OpInterface<"AirExpressionRefOpInterface"> { + let cppNamespace = "::maldoca"; + + let extraClassDeclaration = [{ + operator mlir::Value() { // NOLINT + return getOperation()->getResult(0); + } + }]; +} + +def AirExpressionRefOpTraits : TraitList<[ + DeclareOpInterfaceMethods +]>; diff --git a/maldoca/astgen/test/assign/ir.cc b/maldoca/astgen/test/assign/ir.cc new file mode 100644 index 0000000..f26d91b --- /dev/null +++ b/maldoca/astgen/test/assign/ir.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "maldoca/astgen/test/assign/ir.h" + +// IWYU pragma: begin_keep + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +// IWYU pragma: end_keep + +// ============================================================================= +// Dialect Definition +// ============================================================================= + +#include "maldoca/astgen/test/assign/air_dialect.cc.inc" + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void maldoca::AirDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "maldoca/astgen/test/assign/air_types.cc.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "maldoca/astgen/test/assign/air_ops.generated.cc.inc" + >(); +} + +// ============================================================================= +// Dialect Interface Definitions +// ============================================================================= + +#include "maldoca/astgen/test/assign/interfaces.cc.inc" + +// ============================================================================= +// Dialect Type Definitions +// ============================================================================= + +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/assign/air_types.cc.inc" + +// ============================================================================= +// Dialect Op Definitions +// ============================================================================= + +#define GET_OP_CLASSES +#include "maldoca/astgen/test/assign/air_ops.generated.cc.inc" diff --git a/maldoca/astgen/test/assign/ir.h b/maldoca/astgen/test/assign/ir.h new file mode 100644 index 0000000..5ba6dda --- /dev/null +++ 
b/maldoca/astgen/test/assign/ir.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_ASSIGN_IR_H_ +#define MALDOCA_ASTGEN_TEST_ASSIGN_IR_H_ + +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// Include the auto-generated header file containing the declaration of the AIR +// dialect. +#include "maldoca/astgen/test/assign/air_dialect.h.inc" + +// Include the auto-generated header file containing the declarations of the AIR +// interfaces. +#include "maldoca/astgen/test/assign/interfaces.h.inc" + +// Include the auto-generated header file containing the declarations of the AIR +// types. +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/assign/air_types.h.inc" + +// Include the auto-generated header file containing the declarations of the AIR +// operations. +#define GET_OP_CLASSES +#include "maldoca/astgen/test/assign/air_ops.generated.h.inc" + +#endif // MALDOCA_ASTGEN_TEST_ASSIGN_IR_H_ diff --git a/maldoca/astgen/test/ast_gen_test_util.cc b/maldoca/astgen/test/ast_gen_test_util.cc new file mode 100644 index 0000000..4b32746 --- /dev/null +++ b/maldoca/astgen/test/ast_gen_test_util.cc @@ -0,0 +1,199 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/test/ast_gen_test_util.h" + +#include +#include + +#include "gtest/gtest.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "maldoca/astgen/ast_def.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/ast_gen.h" +#include "maldoca/base/filesystem.h" +#include "maldoca/base/get_runfiles_dir.h" +#include "maldoca/base/status_macros.h" +#include "maldoca/base/testing/status_matchers.h" + +namespace maldoca { + +absl::StatusOr AstGenTest::LoadAstDef() const { + auto ast_def_path = + GetDataDependencyFilepath(GetParam().ast_def_path); + AstDefPb ast_def_pb; + MALDOCA_RETURN_IF_ERROR(ParseTextProtoFile(ast_def_path, &ast_def_pb)); + MALDOCA_ASSIGN_OR_RETURN(AstDef ast_def, AstDef::FromProto(ast_def_pb)); + return ast_def; +} + +TEST_P(AstGenTest, PrintTsInterfaceTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); + + std::string ts_interface = PrintTsInterface(ast_def); + + if (GetParam().ts_interface_path.has_value()) { + auto ts_interface_path = + GetDataDependencyFilepath(*GetParam().ts_interface_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ts_interface, + GetFileContents(ts_interface_path)); + + LOG(INFO) << "ts_interface_path: " << ts_interface_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ts_interface), + absl::StripAsciiWhitespace(expected_ts_interface)); + } +} + +TEST_P(AstGenTest, AstHdrTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); + std::string ast_hdr = + PrintAstHeader(ast_def, 
GetParam().cc_namespace, GetParam().ast_path); + + std::cout << ast_hdr << std::endl; + + if (GetParam().expected_ast_header_path.has_value()) { + auto expected_ast_h_path = + GetDataDependencyFilepath(*GetParam().expected_ast_header_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ast_hdr, + GetFileContents(expected_ast_h_path)); + + LOG(INFO) << " expected_ast_h_path: " << expected_ast_h_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ast_hdr), + absl::StripAsciiWhitespace(expected_ast_hdr)); + } +} + +TEST_P(AstGenTest, AstSrcTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); + std::string ast_src = + PrintAstSource(ast_def, GetParam().cc_namespace, GetParam().ast_path); + + std::cout << ast_src << std::endl; + + if (GetParam().expected_ast_source_path.has_value()) { + auto expected_ast_src_path = + GetDataDependencyFilepath(*GetParam().expected_ast_source_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ast_src, + GetFileContents(expected_ast_src_path)); + + LOG(INFO) << " expected_ast_src_path: " << expected_ast_src_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ast_src), + absl::StripAsciiWhitespace(expected_ast_src)); + } +} + +TEST_P(AstGenTest, AstToJsonTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); + std::string ast_to_json = + PrintAstToJson(ast_def, GetParam().cc_namespace, GetParam().ast_path); + + std::cout << ast_to_json << std::endl; + + if (GetParam().expected_ast_to_json_path.has_value()) { + auto expected_ast_to_json_path = GetDataDependencyFilepath( + *GetParam().expected_ast_to_json_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ast_to_json, + GetFileContents(expected_ast_to_json_path)); + + LOG(INFO) << " expected_ast_to_json_path: " << expected_ast_to_json_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ast_to_json), + absl::StripAsciiWhitespace(expected_ast_to_json)); + } +} + +TEST_P(AstGenTest, AstFromJsonTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); 
+ std::string ast_from_json = + PrintAstFromJson(ast_def, GetParam().cc_namespace, GetParam().ast_path); + + std::cout << ast_from_json << std::endl; + + if (GetParam().expected_ast_from_json_path.has_value()) { + auto expected_ast_from_json_path = GetDataDependencyFilepath( + *GetParam().expected_ast_from_json_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ast_from_json, + GetFileContents(expected_ast_from_json_path)); + + LOG(INFO) << " expected_ast_from_json_path: " + << expected_ast_from_json_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ast_from_json), + absl::StripAsciiWhitespace(expected_ast_from_json)); + } +} + +TEST_P(AstGenTest, IrTableGenTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); + std::string ir_tablegen = PrintIrTableGen(ast_def, GetParam().ir_path); + + // So that we can copy from the output. + std::cout << "Output:" << std::endl; + std::cout << ir_tablegen << std::endl; + + if (GetParam().expected_ir_tablegen_path.has_value()) { + auto expected_ir_tablegen_path = GetDataDependencyFilepath( + *GetParam().expected_ir_tablegen_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ir_tablegen, + GetFileContents(expected_ir_tablegen_path)); + + LOG(INFO) << " expected_ir_tablegen_path: " << expected_ir_tablegen_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ir_tablegen), + absl::StripAsciiWhitespace(expected_ir_tablegen)); + } +} + +TEST_P(AstGenTest, AstToIrTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); + std::string ast_to_ir_source = + PrintAstToIrSource(ast_def, GetParam().cc_namespace, GetParam().ast_path, + GetParam().ir_path); + + std::cout << "Output:" << std::endl; + std::cout << ast_to_ir_source << std::endl; + + if (GetParam().expected_ast_to_ir_source_path.has_value()) { + auto cc_ast_to_ir_source_path = GetDataDependencyFilepath( + *GetParam().expected_ast_to_ir_source_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ast_to_ir_source, + 
GetFileContents(cc_ast_to_ir_source_path)); + + LOG(INFO) << " cc_ast_to_ir_source_path: " << cc_ast_to_ir_source_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ast_to_ir_source), + absl::StripAsciiWhitespace(expected_ast_to_ir_source)); + } +} + +TEST_P(AstGenTest, IrToAstTest) { + MALDOCA_ASSERT_OK_AND_ASSIGN(AstDef ast_def, LoadAstDef()); + std::string ir_to_ast_source = + PrintIrToAstSource(ast_def, GetParam().cc_namespace, GetParam().ast_path, + GetParam().ir_path); + + std::cout << "Output:" << std::endl; + std::cout << ir_to_ast_source << std::endl; + + if (GetParam().expected_ir_to_ast_source_path.has_value()) { + auto cc_ir_to_ast_source_path = GetDataDependencyFilepath( + *GetParam().expected_ir_to_ast_source_path); + MALDOCA_ASSERT_OK_AND_ASSIGN(std::string expected_ir_to_ast_source, + GetFileContents(cc_ir_to_ast_source_path)); + + LOG(INFO) << " cc_ir_to_ast_source_path: " << cc_ir_to_ast_source_path; + EXPECT_EQ(absl::StripAsciiWhitespace(ir_to_ast_source), + absl::StripAsciiWhitespace(expected_ir_to_ast_source)); + } +} + +} // namespace maldoca diff --git a/maldoca/astgen/test/ast_gen_test_util.h b/maldoca/astgen/test/ast_gen_test_util.h new file mode 100644 index 0000000..00f14f0 --- /dev/null +++ b/maldoca/astgen/test/ast_gen_test_util.h @@ -0,0 +1,113 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef MALDOCA_ASTGEN_TEST_AST_GEN_TEST_UTIL_H_ +#define MALDOCA_ASTGEN_TEST_AST_GEN_TEST_UTIL_H_ + +#include +#include + +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "maldoca/astgen/ast_def.h" + +namespace maldoca { + +// Specifies a test case for ASTGen. +struct AstGenTestParam { + // Path to the "ast_def.textproto" AST specification file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/ast_def.textproto" + std::string ast_def_path; + + // Path to the expected generated TypeScript interface file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/ast_ts_interface.generated" + std::optional ts_interface_path; + + // The C++ namespace for the AST classes in C++. + // Example: + // "maldoca::astgen" + std::string cc_namespace; + + // The directory for the AST code in C++. + // Inside the directory, there are the following files: + // - "ast.generated.h" + // - "ast.generated.cc" + // - "ast_to_json.generated.cc" + // - "ast_from_json.generated.cc" + std::string ast_path; + + // The directory for the IR code in TableGen and C++. + // Inside the directory, there are the following files: + // - "ir_ops.generated.td" + // - "conversion/ast_to_hir.generated.cc" + // - "conversion/hir_to_ast.generated.cc" + std::string ir_path; + + // Path to the expected "ast.generated.h" C++ header file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/ast.generated.h" + std::optional expected_ast_header_path; + + // Path to the expected "ast.generated.cc" C++ source file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/ast.generated.cc" + std::optional expected_ast_source_path; + + // Path to the expected "ast_to_json.generated.cc" C++ source file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root.
+ // Example: + // "maldoca/astgen/test/lambda/ast_to_json.generated.cc" + std::optional expected_ast_to_json_path; + + // Path to the expected "ast_from_json.generated.cc" C++ source file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/ast_from_json.generated.cc" + std::optional expected_ast_from_json_path; + + // Path to the expected "ir_ops.generated.td" TableGen source file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/lambdair_ops.generated.td" + std::optional expected_ir_tablegen_path; + + // Path to the expected "ast_to_ir.generated.cc" C++ source file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/ast_to_lambdair.generated.cc" + std::optional expected_ast_to_ir_source_path; + + // Path to the expected "ir_to_ast.generated.cc" C++ source file. + // Resolved via GetDataDependencyFilepath(); relative to the repo root. + // Example: + // "maldoca/astgen/test/lambda/lambdair_to_ast.generated.cc" + std::optional expected_ir_to_ast_source_path; +}; + +class AstGenTest : public ::testing::TestWithParam { + protected: + absl::StatusOr LoadAstDef() const; +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_AST_GEN_TEST_UTIL_H_ diff --git a/maldoca/astgen/test/conversion_test_util.h b/maldoca/astgen/test/conversion_test_util.h new file mode 100644 index 0000000..fd80528 --- /dev/null +++ b/maldoca/astgen/test/conversion_test_util.h @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_CONVERSION_TEST_UTIL_H_ +#define MALDOCA_ASTGEN_TEST_CONVERSION_TEST_UTIL_H_ + +#include +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "gtest/gtest.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/testing/status_matchers.h" + +namespace maldoca { + +class DummyIrToAst {}; + +template +struct ConversionTestCase { + std::string ast_json_string; + std::unique_ptr ast; + OpT (AstToIr::*ast_to_ir_visit)(const NodeT *); + absl::StatusOr> (IrToAst::*ir_to_ast_visit)(OpT); + std::string expected_ir_dump; +}; + +template +void TestIrConversion( + ConversionTestCase &&test_case) { + if (!test_case.ast_json_string.empty()) { + auto ast_json = nlohmann::json::parse(test_case.ast_json_string, + /*callback=*/nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/false); + ASSERT_FALSE(ast_json.is_discarded()) + << "Failed to parse AST: Invalid JSON."; + + MALDOCA_ASSERT_OK_AND_ASSIGN(test_case.ast, NodeT::FromJson(ast_json)); + } + + mlir::MLIRContext context; + context.getOrLoadDialect(); + mlir::OpBuilder builder(&context); + + // A file is modeled as a "module" in MLIR. 
+ mlir::OwningOpRef module = + mlir::ModuleOp::create(builder.getUnknownLoc()); + + mlir::Block *block = &module->getBodyRegion().front(); + builder.setInsertionPointToStart(block); + + AstToIr ast_to_ir(builder); + OpT op = (ast_to_ir.*(test_case.ast_to_ir_visit))(test_case.ast.get()); + + std::string ir_dump; + llvm::raw_string_ostream os(ir_dump); + module->print(os); + + EXPECT_EQ(absl::StripAsciiWhitespace(ir_dump), + absl::StripAsciiWhitespace(test_case.expected_ir_dump)); + + if (test_case.ir_to_ast_visit != nullptr) { + IrToAst ir_to_ast; + MALDOCA_ASSERT_OK_AND_ASSIGN(auto raised_ast, + (ir_to_ast.*(test_case.ir_to_ast_visit))(op)); + + std::stringstream test_case_ss; + test_case.ast->Serialize(test_case_ss); + + std::stringstream raised_ast_ss; + raised_ast->Serialize(raised_ast_ss); + + EXPECT_EQ(test_case_ss.str(), raised_ast_ss.str()); + } +} + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_CONVERSION_TEST_UTIL_H_ diff --git a/maldoca/astgen/test/enum/BUILD b/maldoca/astgen/test/enum/BUILD new file mode 100644 index 0000000..deed21b --- /dev/null +++ b/maldoca/astgen/test/enum/BUILD @@ -0,0 +1,186 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//maldoca/astgen:__subpackages__", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + "eir_ops.generated.td", + "//maldoca/astgen/test/enum/conversion:ast_to_eir.generated.cc", + "//maldoca/astgen/test/enum/conversion:eir_to_ast.generated.cc", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = [ + "ast.generated.h", + ], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) + +td_library( + name = "interfaces_td_files", + srcs = [ + "interfaces.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "interfaces_inc_gen", + tbl_outs = { + "interfaces.h.inc": ["-gen-op-interface-decls"], + "interfaces.cc.inc": ["-gen-op-interface-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "interfaces.td", + deps = [":interfaces_td_files"], +) + +td_library( + name = "eir_dialect_td_files", + srcs = [ + "eir_dialect.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "eir_dialect_inc_gen", + tbl_outs = { + "eir_dialect.h.inc": [ + 
"-gen-dialect-decls", + "-dialect=eir", + ], + "eir_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=eir", + ], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "eir_dialect.td", + deps = [":eir_dialect_td_files"], +) + +td_library( + name = "eir_types_td_files", + srcs = [ + "eir_types.td", + ], + deps = [ + ":eir_dialect_td_files", + ], +) + +gentbl_cc_library( + name = "eir_types_inc_gen", + tbl_outs = { + "eir_types.h.inc": ["-gen-typedef-decls"], + "eir_types.cc.inc": ["-gen-typedef-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "eir_types.td", + deps = [":eir_types_td_files"], +) + +td_library( + name = "eir_ops_generated_td_files", + srcs = [ + "eir_ops.generated.td", + ], + deps = [ + ":eir_dialect_td_files", + ":eir_types_td_files", + ":interfaces_td_files", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "eir_ops_generated_inc_gen", + tbl_outs = { + "eir_ops.generated.h.inc": ["-gen-op-decls"], + "eir_ops.generated.cc.inc": ["-gen-op-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "eir_ops.generated.td", + deps = [":eir_ops_generated_td_files"], +) + +cc_library( + name = "ir", + srcs = ["ir.cc"], + hdrs = ["ir.h"], + deps = [ + ":eir_dialect_inc_gen", + ":eir_ops_generated_inc_gen", + ":eir_types_inc_gen", + ":interfaces_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + ], +) diff --git a/maldoca/astgen/test/enum/ast.generated.cc b/maldoca/astgen/test/enum/ast.generated.cc new file mode 100644 
index 0000000..24eba66 --- /dev/null +++ b/maldoca/astgen/test/enum/ast.generated.cc @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#include "maldoca/astgen/test/enum/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +absl::string_view EUnaryOperatorToString(EUnaryOperator unary_operator) { + switch (unary_operator) { + case EUnaryOperator::kMinus: + return "-"; + case EUnaryOperator::kPlus: + return "+"; + case EUnaryOperator::kNot: + return "!"; + case EUnaryOperator::kBitwiseNot: + return "~"; + case EUnaryOperator::kTypeOf: + return "typeof"; + case EUnaryOperator::kVoid: + return "void"; + case EUnaryOperator::kDelete: + return "delete"; + case EUnaryOperator::kThrow: + return "throw"; + } +} + +absl::StatusOr 
StringToEUnaryOperator(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"-", EUnaryOperator::kMinus}, + {"+", EUnaryOperator::kPlus}, + {"!", EUnaryOperator::kNot}, + {"~", EUnaryOperator::kBitwiseNot}, + {"typeof", EUnaryOperator::kTypeOf}, + {"void", EUnaryOperator::kVoid}, + {"delete", EUnaryOperator::kDelete}, + {"throw", EUnaryOperator::kThrow}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for EUnaryOperator: ", s)); + } + return it->second; +} + +absl::string_view EEscapedCharToString(EEscapedChar escaped_char) { + switch (escaped_char) { + case EEscapedChar::kTab: + return "\t"; + case EEscapedChar::kBackslash: + return "\\"; + } +} + +absl::StatusOr StringToEEscapedChar(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"\t", EEscapedChar::kTab}, + {"\\", EEscapedChar::kBackslash}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for EEscapedChar: ", s)); + } + return it->second; +} + +// ============================================================================= +// ENode +// ============================================================================= + +ENode::ENode( + EUnaryOperator unary_operator, + EEscapedChar escaped_char) + : unary_operator_(std::move(unary_operator)), + escaped_char_(std::move(escaped_char)) {} + +EUnaryOperator ENode::unary_operator() const { + return unary_operator_; +} + +void ENode::set_unary_operator(EUnaryOperator unary_operator) { + unary_operator_ = unary_operator; +} + +EEscapedChar ENode::escaped_char() const { + return escaped_char_; +} + +void ENode::set_escaped_char(EEscapedChar escaped_char) { + escaped_char_ = escaped_char; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/enum/ast.generated.h 
b/maldoca/astgen/test/enum/ast.generated.h new file mode 100644 index 0000000..499ee84 --- /dev/null +++ b/maldoca/astgen/test/enum/ast.generated.h @@ -0,0 +1,97 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_ENUM_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_ENUM_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +enum class EUnaryOperator { + kMinus, + kPlus, + kNot, + kBitwiseNot, + kTypeOf, + kVoid, + kDelete, + kThrow, +}; + +absl::string_view EUnaryOperatorToString(EUnaryOperator unary_operator); +absl::StatusOr StringToEUnaryOperator(absl::string_view s); + +enum class EEscapedChar { + kTab, + kBackslash, +}; + +absl::string_view EEscapedCharToString(EEscapedChar escaped_char); +absl::StatusOr StringToEEscapedChar(absl::string_view s); + +class ENode { + public: + explicit ENode( + EUnaryOperator unary_operator, + EEscapedChar escaped_char); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const 
nlohmann::json& json); + + EUnaryOperator unary_operator() const; + void set_unary_operator(EUnaryOperator unary_operator); + + EEscapedChar escaped_char() const; + void set_escaped_char(EEscapedChar escaped_char); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr GetUnaryOperator(const nlohmann::json& json); + static absl::StatusOr GetEscapedChar(const nlohmann::json& json); + + private: + EUnaryOperator unary_operator_; + EEscapedChar escaped_char_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_ENUM_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/enum/ast_def.textproto b/maldoca/astgen/test/enum/ast_def.textproto new file mode 100644 index 0000000..dbac974 --- /dev/null +++ b/maldoca/astgen/test/enum/ast_def.textproto @@ -0,0 +1,71 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "e" + +# enum UnaryOperator { +# "-" | "+" | "!" | "~" | "typeof" | "void" | "delete" | "throw" +# } +enums { + name: "UnaryOperator" + members { + name: "Minus" + string_value: "-" + } + members { + name: "Plus" + string_value: "+" + } + members { + name: "Not" + string_value: "!" 
+ } + members { + name: "BitwiseNot" + string_value: "~" + } + members { + name: "TypeOf" + string_value: "typeof" + } + members { + name: "Void" + string_value: "void" + } + members { + name: "Delete" + string_value: "delete" + } + members { + name: "Throw" + string_value: "throw" + } +} + +enums { + name: "EscapedChar" + members: { + name: "Tab" + string_value: "\t" + } + members: { + name: "Backslash" + string_value: "\\" + } +} + +nodes { + name: "Node" + fields { + name: "unaryOperator" + type { enum: "UnaryOperator" } + kind: FIELD_KIND_ATTR + } + fields { + name: "escapedChar" + type { enum: "EscapedChar" } + kind: FIELD_KIND_ATTR + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} diff --git a/maldoca/astgen/test/enum/ast_from_json.generated.cc b/maldoca/astgen/test/enum/ast_from_json.generated.cc new file mode 100644 index 0000000..f5f18ea --- /dev/null +++ b/maldoca/astgen/test/enum/ast_from_json.generated.cc @@ -0,0 +1,101 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/enum/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +// ============================================================================= +// ENode +// ============================================================================= + +absl::StatusOr +ENode::GetUnaryOperator(const nlohmann::json& json) { + auto unary_operator_it = json.find("unaryOperator"); + if (unary_operator_it == json.end()) { + return absl::InvalidArgumentError("`unaryOperator` is undefined."); + } + const nlohmann::json& json_unary_operator = unary_operator_it.value(); + + if (json_unary_operator.is_null()) { + return absl::InvalidArgumentError("json_unary_operator is null."); + } + if (!json_unary_operator.is_string()) { + return absl::InvalidArgumentError("`json_unary_operator` expected to be a string."); + } + std::string json_unary_operator_str = json_unary_operator.get(); + return StringToEUnaryOperator(json_unary_operator_str); +} + +absl::StatusOr +ENode::GetEscapedChar(const nlohmann::json& json) { + auto escaped_char_it = json.find("escapedChar"); + if (escaped_char_it == json.end()) { + return absl::InvalidArgumentError("`escapedChar` is undefined."); + } + const nlohmann::json& json_escaped_char = escaped_char_it.value(); + + if (json_escaped_char.is_null()) { + return absl::InvalidArgumentError("json_escaped_char is null."); + } + if (!json_escaped_char.is_string()) { + return absl::InvalidArgumentError("`json_escaped_char` expected to 
be a string."); + } + std::string json_escaped_char_str = json_escaped_char.get(); + return StringToEEscapedChar(json_escaped_char_str); +} + +absl::StatusOr> +ENode::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto unary_operator, ENode::GetUnaryOperator(json)); + MALDOCA_ASSIGN_OR_RETURN(auto escaped_char, ENode::GetEscapedChar(json)); + + return absl::make_unique( + std::move(unary_operator), + std::move(escaped_char)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/enum/ast_gen_test.cc b/maldoca/astgen/test/enum/ast_gen_test.cc new file mode 100644 index 0000000..3638cde --- /dev/null +++ b/maldoca/astgen/test/enum/ast_gen_test.cc @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + Enum, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = + "maldoca/astgen/test/enum/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "enum/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/enum", + .ir_path = "maldoca/astgen/test/enum", + .expected_ast_header_path = + "maldoca/astgen/test/enum/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/enum/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "enum/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "enum/ast_from_json.generated.cc", + .expected_ir_tablegen_path = + "maldoca/astgen/test/enum/eir_ops.generated.td", + .expected_ast_to_ir_source_path = + "maldoca/astgen/test/enum/conversion/" + "ast_to_eir.generated.cc", + .expected_ir_to_ast_source_path = + "maldoca/astgen/test/enum/conversion/" + "eir_to_ast.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/enum/ast_to_json.generated.cc b/maldoca/astgen/test/enum/ast_to_json.generated.cc new file mode 100644 index 0000000..a3dc087 --- /dev/null +++ b/maldoca/astgen/test/enum/ast_to_json.generated.cc @@ -0,0 +1,70 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/enum/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// ENode +// ============================================================================= + +void ENode::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"unaryOperator\":" << "\"" << EUnaryOperatorToString(unary_operator_) << "\""; + MaybeAddComma(os, needs_comma); + os << "\"escapedChar\":" << "\"" << EEscapedCharToString(escaped_char_) << "\""; +} + +void ENode::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + ENode::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/enum/ast_ts_interface.generated b/maldoca/astgen/test/enum/ast_ts_interface.generated new file mode 100644 index 0000000..e083b32 --- /dev/null +++ b/maldoca/astgen/test/enum/ast_ts_interface.generated @@ -0,0 +1,18 @@ +type UnaryOperator = + | "-" + | "+" + | "!" 
+ | "~" + | "typeof" + | "void" + | "delete" + | "throw" + +type EscapedChar = + | "\t" + | "\\" + +interface Node { + unaryOperator: UnaryOperator + escapedChar: EscapedChar +} diff --git a/maldoca/astgen/test/enum/conversion/BUILD b/maldoca/astgen/test/enum/conversion/BUILD new file mode 100644 index 0000000..cf1409b --- /dev/null +++ b/maldoca/astgen/test/enum/conversion/BUILD @@ -0,0 +1,76 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_applicable_licenses = ["//:license"]) + +licenses(["notice"]) + +exports_files([ + "ast_to_eir.generated.cc", + "eir_to_ast.generated.cc", +]) + +cc_library( + name = "ast_to_eir", + srcs = ["ast_to_eir.generated.cc"], + hdrs = ["ast_to_eir.h"], + deps = [ + "//maldoca/astgen/test/enum:ast", + "//maldoca/astgen/test/enum:ir", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "eir_to_ast", + srcs = ["eir_to_ast.generated.cc"], + hdrs = ["eir_to_ast.h"], + deps = [ + "//maldoca/astgen/test/enum:ast", + "//maldoca/astgen/test/enum:ir", + "//maldoca/base:status", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status", 
+ "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_test( + name = "conversion_test", + srcs = ["conversion_test.cc"], + deps = [ + ":ast_to_eir", + ":eir_to_ast", + "//maldoca/astgen/test:conversion_test_util", + "//maldoca/astgen/test/enum:ast", + "//maldoca/astgen/test/enum:ir", + "@googletest//:gtest_main", + ], +) diff --git a/maldoca/astgen/test/enum/conversion/ast_to_eir.generated.cc b/maldoca/astgen/test/enum/conversion/ast_to_eir.generated.cc new file mode 100644 index 0000000..18aba2a --- /dev/null +++ b/maldoca/astgen/test/enum/conversion/ast_to_eir.generated.cc @@ -0,0 +1,58 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/enum/conversion/ast_to_eir.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/astgen/test/enum/ast.generated.h" +#include "maldoca/astgen/test/enum/ir.h" + +namespace maldoca { + +EirNodeOp AstToEir::VisitNode(const ENode *node) { + mlir::StringAttr mlir_unary_operator = builder_.getStringAttr(EUnaryOperatorToString(node->unary_operator())); + mlir::StringAttr mlir_escaped_char = builder_.getStringAttr(EEscapedCharToString(node->escaped_char())); + return CreateExpr(node, mlir_unary_operator, mlir_escaped_char); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/enum/conversion/ast_to_eir.h b/maldoca/astgen/test/enum/conversion/ast_to_eir.h new file mode 100644 index 0000000..9f6d6c4 --- /dev/null +++ b/maldoca/astgen/test/enum/conversion/ast_to_eir.h @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_ENUM_CONVERSION_AST_TO_EIR_H_ +#define MALDOCA_ASTGEN_TEST_ENUM_CONVERSION_AST_TO_EIR_H_ + +#include "mlir/IR/Builders.h" +#include "maldoca/astgen/test/enum/ast.generated.h" +#include "maldoca/astgen/test/enum/ir.h" + +namespace maldoca { + +class AstToEir { + public: + explicit AstToEir(mlir::OpBuilder &builder) : builder_(builder) {} + + EirNodeOp VisitNode(const ENode *node); + + private: + template + Op CreateExpr(const Node *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), + EirAnyType::get(builder_.getContext()), + std::forward(args)...); + } + + mlir::OpBuilder &builder_; +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_ENUM_CONVERSION_AST_TO_EIR_H_ diff --git a/maldoca/astgen/test/enum/conversion/conversion_test.cc b/maldoca/astgen/test/enum/conversion/conversion_test.cc new file mode 100644 index 0000000..7ad6ad2 --- /dev/null +++ b/maldoca/astgen/test/enum/conversion/conversion_test.cc @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "maldoca/astgen/test/conversion_test_util.h" +#include "maldoca/astgen/test/enum/ast.generated.h" +#include "maldoca/astgen/test/enum/conversion/ast_to_eir.h" +#include "maldoca/astgen/test/enum/conversion/eir_to_ast.h" +#include "maldoca/astgen/test/enum/ir.h" + +namespace maldoca { +namespace { + +TEST(ConversionTest, Enum) { + constexpr char kAstJsonString[] = R"( + { + "unaryOperator": "+", + "escapedChar": "\\" + } + )"; + + constexpr char kExpectedIrDump[] = R"( +module { + %0 = "eir.node"() <{escaped_char = "\\", unary_operator = "+"}> : () -> !eir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToEir::VisitNode, + .ir_to_ast_visit = &EirToAst::VisitNode, + .expected_ir_dump = kExpectedIrDump, + }); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/enum/conversion/eir_to_ast.generated.cc b/maldoca/astgen/test/enum/conversion/eir_to_ast.generated.cc new file mode 100644 index 0000000..b4e7630 --- /dev/null +++ b/maldoca/astgen/test/enum/conversion/eir_to_ast.generated.cc @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/enum/conversion/eir_to_ast.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/base/status_macros.h" +#include "maldoca/astgen/test/enum/ast.generated.h" +#include "maldoca/astgen/test/enum/ir.h" + +namespace maldoca { + +absl::StatusOr> +EirToAst::VisitNode(EirNodeOp op) { + MALDOCA_ASSIGN_OR_RETURN(EUnaryOperator unary_operator, StringToEUnaryOperator(op.getUnaryOperatorAttr().str())); + MALDOCA_ASSIGN_OR_RETURN(EEscapedChar escaped_char, StringToEEscapedChar(op.getEscapedCharAttr().str())); + return Create( + op, + std::move(unary_operator), + std::move(escaped_char)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/enum/conversion/eir_to_ast.h b/maldoca/astgen/test/enum/conversion/eir_to_ast.h new file mode 100644 index 0000000..0a45e5f --- /dev/null +++ b/maldoca/astgen/test/enum/conversion/eir_to_ast.h @@ -0,0 +1,39 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_ENUM_CONVERSION_EIR_TO_AST_H_ +#define MALDOCA_ASTGEN_TEST_ENUM_CONVERSION_EIR_TO_AST_H_ + +#include + +#include "mlir/IR/Operation.h" +#include "absl/status/statusor.h" +#include "maldoca/astgen/test/enum/ast.generated.h" +#include "maldoca/astgen/test/enum/ir.h" + +namespace maldoca { + +class EirToAst { + public: + absl::StatusOr> VisitNode(EirNodeOp op); + + template + std::unique_ptr Create(mlir::Operation *op, Args &&...args) { + return absl::make_unique(std::forward(args)...); + } +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_ENUM_CONVERSION_EIR_TO_AST_H_ diff --git a/maldoca/astgen/test/enum/eir_dialect.td b/maldoca/astgen/test/enum/eir_dialect.td new file mode 100644 index 0000000..e35b09f --- /dev/null +++ b/maldoca/astgen/test/enum/eir_dialect.td @@ -0,0 +1,42 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_ENUM_EIR_DIALECT_TD_ +#define MALDOCA_ASTGEN_TEST_ENUM_EIR_DIALECT_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def Eir_Dialect : Dialect { + let name = "eir"; + let cppNamespace = "::maldoca"; + + let description = [{ + The EnumIR, a test IR that models enums. All ops and fields are + directly mapped from the AST. + }]; + + let useDefaultTypePrinterParser = 1; +} + +class Eir_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = ?; +} + +class Eir_Op traits = []> : + Op; + +#endif // MALDOCA_ASTGEN_TEST_ENUM_EIR_DIALECT_TD_ diff --git a/maldoca/astgen/test/enum/eir_ops.generated.td b/maldoca/astgen/test/enum/eir_ops.generated.td new file mode 100644 index 0000000..e1fb303 --- /dev/null +++ b/maldoca/astgen/test/enum/eir_ops.generated.td @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_ENUM_EIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_ENUM_EIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/enum/interfaces.td" +include "maldoca/astgen/test/enum/eir_dialect.td" +include "maldoca/astgen/test/enum/eir_types.td" + +def EirNodeOp : Eir_Op<"node", []> { + let arguments = (ins + StrAttr: $unary_operator, + StrAttr: $escaped_char + ); + + let results = (outs + EirAnyType + ); +} + +#endif // MALDOCA_ASTGEN_TEST_ENUM_EIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/enum/eir_types.td b/maldoca/astgen/test/enum/eir_types.td new file mode 100644 index 0000000..3774aea --- /dev/null +++ b/maldoca/astgen/test/enum/eir_types.td @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_ENUM_EIR_TYPES_TD_ +#define MALDOCA_ASTGEN_TEST_ENUM_EIR_TYPES_TD_ + +include "maldoca/astgen/test/enum/eir_dialect.td" + +def EirAnyType : Eir_Type<"EirAny"> { + let summary = "A placeholder singleton type."; + let mnemonic = "any"; + let assemblyFormat = ""; +} + +#endif // MALDOCA_ASTGEN_TEST_ENUM_EIR_TYPES_TD_ diff --git a/maldoca/astgen/test/enum/interfaces.td b/maldoca/astgen/test/enum/interfaces.td new file mode 100644 index 0000000..5d18117 --- /dev/null +++ b/maldoca/astgen/test/enum/interfaces.td @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "mlir/IR/OpBase.td" diff --git a/maldoca/astgen/test/enum/ir.cc b/maldoca/astgen/test/enum/ir.cc new file mode 100644 index 0000000..d491c93 --- /dev/null +++ b/maldoca/astgen/test/enum/ir.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "maldoca/astgen/test/enum/ir.h" + +// IWYU pragma: begin_keep + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +// IWYU pragma: end_keep + +// ============================================================================= +// Dialect Definition +// ============================================================================= + +#include "maldoca/astgen/test/enum/eir_dialect.cc.inc" + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void maldoca::EirDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "maldoca/astgen/test/enum/eir_types.cc.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "maldoca/astgen/test/enum/eir_ops.generated.cc.inc" + >(); +} + +// ============================================================================= +// Dialect Interface Definitions +// ============================================================================= + +#include "maldoca/astgen/test/enum/interfaces.cc.inc" + +// ============================================================================= +// Dialect Type Definitions +// ============================================================================= + +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/enum/eir_types.cc.inc" + +// ============================================================================= +// Dialect Op Definitions +// ============================================================================= + +#define GET_OP_CLASSES +#include "maldoca/astgen/test/enum/eir_ops.generated.cc.inc" diff --git a/maldoca/astgen/test/enum/ir.h b/maldoca/astgen/test/enum/ir.h new file mode 100644 index 0000000..f3bfa80 --- /dev/null +++ b/maldoca/astgen/test/enum/ir.h @@ 
-0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_ENUM_IR_H_ +#define MALDOCA_ASTGEN_TEST_ENUM_IR_H_ + +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// Include the auto-generated header file containing the declaration of the EIR +// dialect. +#include "maldoca/astgen/test/enum/eir_dialect.h.inc" + +// Include the auto-generated header file containing the declarations of the EIR +// interfaces. +#include "maldoca/astgen/test/enum/interfaces.h.inc" + +// Include the auto-generated header file containing the declarations of the EIR +// types. +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/enum/eir_types.h.inc" + +// Include the auto-generated header file containing the declarations of the EIR +// operations. +#define GET_OP_CLASSES +#include "maldoca/astgen/test/enum/eir_ops.generated.h.inc" + +#endif // MALDOCA_ASTGEN_TEST_ENUM_IR_H_ diff --git a/maldoca/astgen/test/lambda/BUILD b/maldoca/astgen/test/lambda/BUILD new file mode 100644 index 0000000..9023956 --- /dev/null +++ b/maldoca/astgen/test/lambda/BUILD @@ -0,0 +1,179 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//maldoca/astgen:__subpackages__", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + "lair_ops.generated.td", + "//maldoca/astgen/test/lambda/conversion:ast_to_lair.generated.cc", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = ["ast.generated.h"], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) + +td_library( + name = "interfaces_td_files", + srcs = [ + "interfaces.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "interfaces_inc_gen", + tbl_outs = { + "interfaces.h.inc": ["-gen-op-interface-decls"], + "interfaces.cc.inc": ["-gen-op-interface-defs"], + }, + tblgen = 
"@llvm-project//mlir:mlir-tblgen", + td_file = "interfaces.td", + deps = [":interfaces_td_files"], +) + +td_library( + name = "lair_dialect_td_files", + srcs = [ + "lair_dialect.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "lair_dialect_inc_gen", + tbl_outs = { + "lair_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=lair", + ], + "lair_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=lair", + ], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lair_dialect.td", + deps = [":lair_dialect_td_files"], +) + +td_library( + name = "lair_types_td_files", + srcs = [ + "lair_types.td", + ], + deps = [ + ":lair_dialect_td_files", + ], +) + +gentbl_cc_library( + name = "lair_types_inc_gen", + tbl_outs = { + "lair_types.h.inc": ["-gen-typedef-decls"], + "lair_types.cc.inc": ["-gen-typedef-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lair_types.td", + deps = [":lair_types_td_files"], +) + +td_library( + name = "lair_ops_td_files", + srcs = [ + "lair_ops.generated.td", + "lair_ops.td", + ], + deps = [ + ":interfaces_td_files", + ":lair_dialect_td_files", + ":lair_types_td_files", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "lair_ops_inc_gen", + tbl_outs = { + "lair_ops.h.inc": ["-gen-op-decls"], + "lair_ops.cc.inc": ["-gen-op-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lair_ops.td", + deps = [":lair_ops_td_files"], +) + +cc_library( + name = "ir", + srcs = ["ir.cc"], + hdrs = ["ir.h"], + deps = [ + ":interfaces_inc_gen", + ":lair_dialect_inc_gen", + ":lair_ops_inc_gen", + ":lair_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) 
diff --git a/maldoca/astgen/test/lambda/ast.generated.cc b/maldoca/astgen/test/lambda/ast.generated.cc new file mode 100644 index 0000000..221ab1e --- /dev/null +++ b/maldoca/astgen/test/lambda/ast.generated.cc @@ -0,0 +1,165 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#include "maldoca/astgen/test/lambda/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// LaExpression +// ============================================================================= + +absl::string_view LaExpressionTypeToString(LaExpressionType expression_type) { + switch (expression_type) { + case LaExpressionType::kVariable: + return "Variable"; + case 
LaExpressionType::kFunctionDefinition: + return "FunctionDefinition"; + case LaExpressionType::kFunctionCall: + return "FunctionCall"; + } +} + +absl::StatusOr StringToLaExpressionType(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"Variable", LaExpressionType::kVariable}, + {"FunctionDefinition", LaExpressionType::kFunctionDefinition}, + {"FunctionCall", LaExpressionType::kFunctionCall}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for LaExpressionType: ", s)); + } + return it->second; +} + +// ============================================================================= +// LaVariable +// ============================================================================= + +LaVariable::LaVariable( + std::string identifier) + : LaExpression(), + identifier_(std::move(identifier)) {} + +absl::string_view LaVariable::identifier() const { + return identifier_; +} + +void LaVariable::set_identifier(std::string identifier) { + identifier_ = std::move(identifier); +} + +// ============================================================================= +// LaFunctionDefinition +// ============================================================================= + +LaFunctionDefinition::LaFunctionDefinition( + std::unique_ptr parameter, + std::unique_ptr body) + : LaExpression(), + parameter_(std::move(parameter)), + body_(std::move(body)) {} + +LaVariable* LaFunctionDefinition::parameter() { + return parameter_.get(); +} + +const LaVariable* LaFunctionDefinition::parameter() const { + return parameter_.get(); +} + +void LaFunctionDefinition::set_parameter(std::unique_ptr parameter) { + parameter_ = std::move(parameter); +} + +LaExpression* LaFunctionDefinition::body() { + return body_.get(); +} + +const LaExpression* LaFunctionDefinition::body() const { + return body_.get(); +} + +void LaFunctionDefinition::set_body(std::unique_ptr body) { + body_ = std::move(body); +} + 
+// ============================================================================= +// LaFunctionCall +// ============================================================================= + +LaFunctionCall::LaFunctionCall( + std::unique_ptr function, + std::unique_ptr argument) + : LaExpression(), + function_(std::move(function)), + argument_(std::move(argument)) {} + +LaExpression* LaFunctionCall::function() { + return function_.get(); +} + +const LaExpression* LaFunctionCall::function() const { + return function_.get(); +} + +void LaFunctionCall::set_function(std::unique_ptr function) { + function_ = std::move(function); +} + +LaExpression* LaFunctionCall::argument() { + return argument_.get(); +} + +const LaExpression* LaFunctionCall::argument() const { + return argument_.get(); +} + +void LaFunctionCall::set_argument(std::unique_ptr argument) { + argument_ = std::move(argument); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/lambda/ast.generated.h b/maldoca/astgen/test/lambda/ast.generated.h new file mode 100644 index 0000000..19439cd --- /dev/null +++ b/maldoca/astgen/test/lambda/ast.generated.h @@ -0,0 +1,175 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_LAMBDA_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_LAMBDA_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +enum class LaExpressionType { + kVariable, + kFunctionDefinition, + kFunctionCall, +}; + +absl::string_view LaExpressionTypeToString(LaExpressionType expression_type); +absl::StatusOr StringToLaExpressionType(absl::string_view s); + +class LaExpression { + public: + virtual ~LaExpression() = default; + + virtual LaExpressionType expression_type() const = 0; + + virtual void Serialize(std::ostream& os) const = 0; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class LaVariable : public virtual LaExpression { + public: + explicit LaVariable( + std::string identifier); + + LaExpressionType expression_type() const override { + return LaExpressionType::kVariable; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + absl::string_view identifier() const; + void set_identifier(std::string identifier); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr GetIdentifier(const nlohmann::json& json); + + private: + std::string identifier_; +}; + +class LaFunctionDefinition : public virtual LaExpression { + public: + explicit LaFunctionDefinition( + std::unique_ptr parameter, + std::unique_ptr body); + + LaExpressionType expression_type() const override { + return LaExpressionType::kFunctionDefinition; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + LaVariable* parameter(); + const LaVariable* parameter() const; + void set_parameter(std::unique_ptr parameter); + + LaExpression* body(); + const LaExpression* body() const; + void set_body(std::unique_ptr body); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetParameter(const nlohmann::json& json); + static absl::StatusOr> GetBody(const nlohmann::json& json); + + private: + std::unique_ptr parameter_; + std::unique_ptr body_; +}; + +class LaFunctionCall : public virtual LaExpression { + public: + explicit LaFunctionCall( + std::unique_ptr function, + std::unique_ptr argument); + + LaExpressionType expression_type() const override { + return LaExpressionType::kFunctionCall; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + LaExpression* function(); + const LaExpression* function() const; + void set_function(std::unique_ptr function); + + LaExpression* argument(); + const LaExpression* argument() const; + void set_argument(std::unique_ptr argument); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. 
+ void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetFunction(const nlohmann::json& json); + static absl::StatusOr> GetArgument(const nlohmann::json& json); + + private: + std::unique_ptr function_; + std::unique_ptr argument_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_LAMBDA_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/lambda/ast_def.textproto b/maldoca/astgen/test/lambda/ast_def.textproto new file mode 100644 index 0000000..bd981a5 --- /dev/null +++ b/maldoca/astgen/test/lambda/ast_def.textproto @@ -0,0 +1,75 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "la" + +# interface Expression { +# type: string +# } +nodes { + name: "Expression" + kinds: FIELD_KIND_RVAL +} + +# interface Variable <: Expression { +# type: "Variable" +# identifier: string +# } +nodes { + name: "Variable" + type: "Variable" + parents: "Expression" + fields { + name: "identifier" + type { string {} } + kind: FIELD_KIND_ATTR + } + kinds: FIELD_KIND_LVAL + should_generate_ir_op: true +} + +# interface FunctionDefinition <: Expression { +# type: "FunctionDefinition" +# parameter: Variable +# body: Expression +# } +nodes { + name: "FunctionDefinition" + type: "FunctionDefinition" + parents: "Expression" + fields { + name: "parameter" + type { class: "Variable" } + kind: FIELD_KIND_LVAL + enclose_in_region: true + } + fields { + name: "body" + type { class: "Expression" } + kind: FIELD_KIND_RVAL + enclose_in_region: true + } + should_generate_ir_op: true +} + +# interface FunctionCall <: Expression { +# type: "FunctionCall" +# function: Expression +# argument: Expression +# } +nodes { + name: "FunctionCall" + type: "FunctionCall" + parents: "Expression" + fields { + name: "function" + type { class: "Expression" } + 
kind: FIELD_KIND_RVAL + } + fields { + name: "argument" + type { class: "Expression" } + kind: FIELD_KIND_RVAL + } + should_generate_ir_op: true +} diff --git a/maldoca/astgen/test/lambda/ast_from_json.generated.cc b/maldoca/astgen/test/lambda/ast_from_json.generated.cc new file mode 100644 index 0000000..1e2902e --- /dev/null +++ b/maldoca/astgen/test/lambda/ast_from_json.generated.cc @@ -0,0 +1,209 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/lambda/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +static absl::StatusOr GetType(const nlohmann::json& json) { + auto type_it = json.find("type"); + if (type_it == json.end()) { + return absl::InvalidArgumentError("`type` is undefined."); + } + const nlohmann::json& json_type = type_it.value(); + if (json_type.is_null()) { + return absl::InvalidArgumentError("json_type is null."); + } + if (!json_type.is_string()) { + return absl::InvalidArgumentError("`json_type` expected to be string."); + } + return json_type.get(); +} + +// ============================================================================= +// LaExpression +// ============================================================================= + +absl::StatusOr> +LaExpression::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "Variable") { + return LaVariable::FromJson(json); + } else if (type == "FunctionDefinition") { + return LaFunctionDefinition::FromJson(json); + } else if (type == "FunctionCall") { + return LaFunctionCall::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// LaVariable +// 
============================================================================= + +absl::StatusOr +LaVariable::GetIdentifier(const nlohmann::json& json) { + auto identifier_it = json.find("identifier"); + if (identifier_it == json.end()) { + return absl::InvalidArgumentError("`identifier` is undefined."); + } + const nlohmann::json& json_identifier = identifier_it.value(); + + if (json_identifier.is_null()) { + return absl::InvalidArgumentError("json_identifier is null."); + } + if (!json_identifier.is_string()) { + return absl::InvalidArgumentError("Expecting json_identifier.is_string()."); + } + return json_identifier.get(); +} + +absl::StatusOr> +LaVariable::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto identifier, LaVariable::GetIdentifier(json)); + + return absl::make_unique( + std::move(identifier)); +} + +// ============================================================================= +// LaFunctionDefinition +// ============================================================================= + +absl::StatusOr> +LaFunctionDefinition::GetParameter(const nlohmann::json& json) { + auto parameter_it = json.find("parameter"); + if (parameter_it == json.end()) { + return absl::InvalidArgumentError("`parameter` is undefined."); + } + const nlohmann::json& json_parameter = parameter_it.value(); + + if (json_parameter.is_null()) { + return absl::InvalidArgumentError("json_parameter is null."); + } + return LaVariable::FromJson(json_parameter); +} + +absl::StatusOr> +LaFunctionDefinition::GetBody(const nlohmann::json& json) { + auto body_it = json.find("body"); + if (body_it == json.end()) { + return absl::InvalidArgumentError("`body` is undefined."); + } + const nlohmann::json& json_body = body_it.value(); + + if (json_body.is_null()) { + return absl::InvalidArgumentError("json_body is null."); + } + return LaExpression::FromJson(json_body); +} + 
+absl::StatusOr> +LaFunctionDefinition::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto parameter, LaFunctionDefinition::GetParameter(json)); + MALDOCA_ASSIGN_OR_RETURN(auto body, LaFunctionDefinition::GetBody(json)); + + return absl::make_unique( + std::move(parameter), + std::move(body)); +} + +// ============================================================================= +// LaFunctionCall +// ============================================================================= + +absl::StatusOr> +LaFunctionCall::GetFunction(const nlohmann::json& json) { + auto function_it = json.find("function"); + if (function_it == json.end()) { + return absl::InvalidArgumentError("`function` is undefined."); + } + const nlohmann::json& json_function = function_it.value(); + + if (json_function.is_null()) { + return absl::InvalidArgumentError("json_function is null."); + } + return LaExpression::FromJson(json_function); +} + +absl::StatusOr> +LaFunctionCall::GetArgument(const nlohmann::json& json) { + auto argument_it = json.find("argument"); + if (argument_it == json.end()) { + return absl::InvalidArgumentError("`argument` is undefined."); + } + const nlohmann::json& json_argument = argument_it.value(); + + if (json_argument.is_null()) { + return absl::InvalidArgumentError("json_argument is null."); + } + return LaExpression::FromJson(json_argument); +} + +absl::StatusOr> +LaFunctionCall::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto function, LaFunctionCall::GetFunction(json)); + MALDOCA_ASSIGN_OR_RETURN(auto argument, LaFunctionCall::GetArgument(json)); + + return absl::make_unique( + std::move(function), + std::move(argument)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca 
diff --git a/maldoca/astgen/test/lambda/ast_gen_test.cc b/maldoca/astgen/test/lambda/ast_gen_test.cc new file mode 100644 index 0000000..95609d5 --- /dev/null +++ b/maldoca/astgen/test/lambda/ast_gen_test.cc @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + Lambda, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = + "maldoca/astgen/test/lambda/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "lambda/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/lambda", + .ir_path = "maldoca/astgen/test/lambda", + .expected_ast_header_path = + "maldoca/astgen/test/lambda/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/lambda/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "lambda/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "lambda/ast_from_json.generated.cc", + .expected_ir_tablegen_path = + "maldoca/astgen/test/" + "lambda/lair_ops.generated.td", + .expected_ast_to_ir_source_path = + "maldoca/astgen/test/" + "lambda/conversion/ast_to_lair.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/lambda/ast_to_json.generated.cc 
b/maldoca/astgen/test/lambda/ast_to_json.generated.cc new file mode 100644 index 0000000..1e5faf2 --- /dev/null +++ b/maldoca/astgen/test/lambda/ast_to_json.generated.cc @@ -0,0 +1,128 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/lambda/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// LaExpression +// ============================================================================= + +void LaExpression::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +// ============================================================================= +// LaVariable +// 
============================================================================= + +void LaVariable::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"identifier\":" << (nlohmann::json(identifier_)).dump(); +} + +void LaVariable::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"Variable\""; + LaExpression::SerializeFields(os, needs_comma); + LaVariable::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LaFunctionDefinition +// ============================================================================= + +void LaFunctionDefinition::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"parameter\":"; + parameter_->Serialize(os); + MaybeAddComma(os, needs_comma); + os << "\"body\":"; + body_->Serialize(os); +} + +void LaFunctionDefinition::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"FunctionDefinition\""; + LaExpression::SerializeFields(os, needs_comma); + LaFunctionDefinition::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LaFunctionCall +// ============================================================================= + +void LaFunctionCall::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"function\":"; + function_->Serialize(os); + MaybeAddComma(os, needs_comma); + os << "\"argument\":"; + argument_->Serialize(os); +} + +void LaFunctionCall::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"FunctionCall\""; + LaExpression::SerializeFields(os, needs_comma); + 
LaFunctionCall::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/lambda/ast_ts_interface.generated b/maldoca/astgen/test/lambda/ast_ts_interface.generated new file mode 100644 index 0000000..b1fabf4 --- /dev/null +++ b/maldoca/astgen/test/lambda/ast_ts_interface.generated @@ -0,0 +1,16 @@ +interface Expression { +} + +interface Variable <: Expression { + identifier: string +} + +interface FunctionDefinition <: Expression { + parameter: Variable + body: Expression +} + +interface FunctionCall <: Expression { + function: Expression + argument: Expression +} diff --git a/maldoca/astgen/test/lambda/conversion/BUILD b/maldoca/astgen/test/lambda/conversion/BUILD new file mode 100644 index 0000000..3046582 --- /dev/null +++ b/maldoca/astgen/test/lambda/conversion/BUILD @@ -0,0 +1,53 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_applicable_licenses = ["//:license"]) + +licenses(["notice"]) + +exports_files([ + "ast_to_lair.generated.cc", +]) + +cc_library( + name = "ast_to_lair", + srcs = ["ast_to_lair.generated.cc"], + hdrs = ["ast_to_lair.h"], + deps = [ + "//maldoca/astgen/test/lambda:ast", + "//maldoca/astgen/test/lambda:ir", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_test( + name = "conversion_test", + srcs = ["conversion_test.cc"], + deps = [ + ":ast_to_lair", + "//maldoca/astgen/test:conversion_test_util", + "//maldoca/astgen/test/lambda:ast", + "//maldoca/astgen/test/lambda:ir", + "@googletest//:gtest_main", + ], +) diff --git a/maldoca/astgen/test/lambda/conversion/ast_to_lair.generated.cc b/maldoca/astgen/test/lambda/conversion/ast_to_lair.generated.cc new file mode 100644 index 0000000..e99e573 --- /dev/null +++ b/maldoca/astgen/test/lambda/conversion/ast_to_lair.generated.cc @@ -0,0 +1,96 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/lambda/conversion/ast_to_lair.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/astgen/test/lambda/ast.generated.h" +#include "maldoca/astgen/test/lambda/ir.h" + +namespace maldoca { + +LairExpressionOpInterface AstToLair::VisitExpression(const LaExpression *node) { + if (auto *variable = dynamic_cast(node)) { + return VisitVariable(variable); + } + if (auto *function_definition = dynamic_cast(node)) { + return VisitFunctionDefinition(function_definition); + } + if (auto *function_call = dynamic_cast(node)) { + return VisitFunctionCall(function_call); + } + LOG(FATAL) << "Unreachable code."; +} + +LairVariableOp AstToLair::VisitVariable(const LaVariable *node) { + mlir::StringAttr mlir_identifier = builder_.getStringAttr(node->identifier()); + return CreateExpr(node, mlir_identifier); +} + +LairVariableRefOp AstToLair::VisitVariableRef(const LaVariable *node) { + mlir::StringAttr mlir_identifier = builder_.getStringAttr(node->identifier()); + return CreateExpr(node, mlir_identifier); +} + +LairFunctionDefinitionOp AstToLair::VisitFunctionDefinition(const LaFunctionDefinition *node) { + auto op = CreateExpr(node); + mlir::Region &mlir_parameter_region = op.getParameter(); + AppendNewBlockAndPopulate(mlir_parameter_region, [&] { + mlir::Value mlir_parameter = VisitVariableRef(node->parameter()); + CreateStmt(node, 
mlir_parameter); + }); + mlir::Region &mlir_body_region = op.getBody(); + AppendNewBlockAndPopulate(mlir_body_region, [&] { + mlir::Value mlir_body = VisitExpression(node->body()); + CreateStmt(node, mlir_body); + }); + return op; +} + +LairFunctionCallOp AstToLair::VisitFunctionCall(const LaFunctionCall *node) { + mlir::Value mlir_function = VisitExpression(node->function()); + mlir::Value mlir_argument = VisitExpression(node->argument()); + return CreateExpr(node, mlir_function, mlir_argument); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/lambda/conversion/ast_to_lair.h b/maldoca/astgen/test/lambda/conversion/ast_to_lair.h new file mode 100644 index 0000000..c8d1c84 --- /dev/null +++ b/maldoca/astgen/test/lambda/conversion/ast_to_lair.h @@ -0,0 +1,60 @@ +#ifndef MALDOCA_ASTGEN_TEST_LAMBDA_CONVERSION_AST_TO_LAIR_H_ +#define MALDOCA_ASTGEN_TEST_LAMBDA_CONVERSION_AST_TO_LAIR_H_ + +#include +#include + +#include "mlir/IR/Builders.h" +#include "absl/cleanup/cleanup.h" +#include "maldoca/astgen/test/lambda/ast.generated.h" +#include "maldoca/astgen/test/lambda/ir.h" + +namespace maldoca { + +class AstToLair { + public: + explicit AstToLair(mlir::OpBuilder &builder) : builder_(builder) {} + + LairExpressionOpInterface VisitExpression(const LaExpression *node); + + LairVariableOp VisitVariable(const LaVariable *node); + + LairVariableRefOp VisitVariableRef(const LaVariable *node); + + LairFunctionDefinitionOp VisitFunctionDefinition( + const LaFunctionDefinition *node); + + LairFunctionCallOp VisitFunctionCall(const LaFunctionCall *node); + + private: + template + Op CreateExpr(const JsNode *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), + std::forward(args)...); + } + + template + Op CreateStmt(const JsNode *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), std::nullopt, + std::forward(args)...); + } + + void 
AppendNewBlockAndPopulate(mlir::Region &region, + std::function populate) { + // Save insertion point. + // Will revert at the end. + mlir::OpBuilder::InsertionGuard insertion_guard(builder_); + + // Insert new block and point builder to it. + mlir::Block &block = region.emplaceBlock(); + builder_.setInsertionPointToStart(&block); + + populate(); + } + + mlir::OpBuilder &builder_; +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_LAMBDA_CONVERSION_AST_TO_LAIR_H_ diff --git a/maldoca/astgen/test/lambda/conversion/conversion_test.cc b/maldoca/astgen/test/lambda/conversion/conversion_test.cc new file mode 100644 index 0000000..05586a5 --- /dev/null +++ b/maldoca/astgen/test/lambda/conversion/conversion_test.cc @@ -0,0 +1,105 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/conversion_test_util.h" +#include "maldoca/astgen/test/lambda/ast.generated.h" +#include "maldoca/astgen/test/lambda/conversion/ast_to_lair.h" +#include "maldoca/astgen/test/lambda/ir.h" + +namespace maldoca { +namespace { + +TEST(ConversionTest, SimpleFunctionDefinition) { + // x => x + constexpr char kAstJsonString[] = R"( + { + "body": { + "identifier": "x", + "type": "Variable" + }, + "parameter": { + "identifier": "x", + "type": "Variable" + }, + "type": "FunctionDefinition" + } + )"; + + constexpr char kExpectedIrDump[] = R"( +module { + %0 = "lair.function_definition"() ({ + %1 = "lair.variable_ref"() <{identifier = "x"}> : () -> !lair.any + "lair.expr_region_end"(%1) : (!lair.any) -> () + }, { + %1 = "lair.variable"() <{identifier = "x"}> : () -> !lair.any + "lair.expr_region_end"(%1) : (!lair.any) -> () + }) : () -> !lair.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToLair::VisitFunctionDefinition, + .expected_ir_dump = kExpectedIrDump, + }); +} + +TEST(ConversionTest, FunctionCall) { + // (x => x)(x) + constexpr char kAstJsonString[] = R"( + { + "argument": { + "identifier": "x", + "type": "Variable" + }, + "function": { + "body": { + "identifier": "x", + "type": "Variable" + }, + "parameter": { + "identifier": "x", + "type": "Variable" + }, + "type": "FunctionDefinition" + }, + "type": "FunctionCall" + } + )"; + + const char kExpectedIrDump[] = R"( +module { + %0 = "lair.function_definition"() ({ + %3 = "lair.variable_ref"() <{identifier = "x"}> : () -> !lair.any + "lair.expr_region_end"(%3) : (!lair.any) -> () + }, { + %3 = "lair.variable"() <{identifier = "x"}> : () -> !lair.any + "lair.expr_region_end"(%3) : (!lair.any) -> () + }) : () -> !lair.any + %1 = "lair.variable"() <{identifier = "x"}> : () -> !lair.any + %2 = "lair.function_call"(%0, %1) : (!lair.any, !lair.any) -> !lair.any +} + )"; + + TestIrConversion({ + .ast_json_string = 
kAstJsonString, + .ast_to_ir_visit = &AstToLair::VisitFunctionCall, + .expected_ir_dump = kExpectedIrDump, + }); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/lambda/conversion/lair_to_ast.generated.cc b/maldoca/astgen/test/lambda/conversion/lair_to_ast.generated.cc new file mode 100644 index 0000000..154085c --- /dev/null +++ b/maldoca/astgen/test/lambda/conversion/lair_to_ast.generated.cc @@ -0,0 +1,141 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/lambda/conversion/lair_to_ast.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/base/status_macros.h" +#include "maldoca/astgen/test/lambda/ast.generated.h" +#include "maldoca/astgen/test/lambda/ir.h" + +namespace maldoca { + +absl::StatusOr> +LairToAst::VisitExpression(LairExpressionOpInterface op) { + using Ret = absl::StatusOr>; + return llvm::TypeSwitch(op) + .Case([&](LairVariableOp op) { + return VisitVariable(op); + }) + .Case([&](LairFunctionDefinitionOp op) { + return VisitFunctionDefinition(op); + }) + .Case([&](LairFunctionCallOp op) { + return VisitFunctionCall(op); + }) + .Default([&](mlir::Operation* op) { + return absl::InvalidArgumentError("Unrecognized op"); + }); +} + +absl::StatusOr> +LairToAst::VisitVariable(LairVariableOp op) { + std::string identifier = op.getIdentifierAttr().str(); + return Create( + op, + std::move(identifier)); +} + +absl::StatusOr> +LairToAst::VisitVariableRef(LairVariableRefOp op) { + std::string identifier = op.getIdentifierAttr().str(); + return Create( + op, + std::move(identifier)); +} + +absl::StatusOr> +LairToAst::VisitFunctionDefinition(LairFunctionDefinitionOp op) { + 
MALDOCA_ASSIGN_OR_RETURN(auto mlir_parameter_value, GetExprRegionValue(op.getParameter())); + auto parameter_op = llvm::dyn_cast(mlir_parameter_value.getDefiningOp()); + if (parameter_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected LairVariableRefOp, got ", + mlir_parameter_value.getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr parameter, VisitVariableRef(parameter_op)); + MALDOCA_ASSIGN_OR_RETURN(auto mlir_body_value, GetExprRegionValue(op.getBody())); + auto body_op = llvm::dyn_cast(mlir_body_value.getDefiningOp()); + if (body_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected LairExpressionOpInterface, got ", + mlir_body_value.getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr body, VisitExpression(body_op)); + return Create( + op, + std::move(parameter), + std::move(body)); +} + +absl::StatusOr> +LairToAst::VisitFunctionCall(LairFunctionCallOp op) { + auto function_op = llvm::dyn_cast(op.getFunction().getDefiningOp()); + if (function_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected LairExpressionOpInterface, got ", + op.getFunction().getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr function, VisitExpression(function_op)); + auto argument_op = llvm::dyn_cast(op.getArgument().getDefiningOp()); + if (argument_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected LairExpressionOpInterface, got ", + op.getArgument().getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr argument, VisitExpression(argument_op)); + return Create( + op, + std::move(function), + std::move(argument)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/lambda/interfaces.td 
b/maldoca/astgen/test/lambda/interfaces.td new file mode 100644 index 0000000..7a049b4 --- /dev/null +++ b/maldoca/astgen/test/lambda/interfaces.td @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definition of op interfaces used by the lair dialect. +// +// Just like we model leaf classes as MLIR ops, we model non-leaf classes as +// MLIR interfaces. +// +// For example, `FunctionDefinition` inherits from `Expression`, so we define an +// interface `LairExpressionOpInterface`. +// +// This way, we can implicitly convert a `LairFunctionDefinitionOp` to a +// `LairExpressionOpInterface`: +// +// ``` +// LairFunctionDefinitionOp function_definition = ...; +// LairExpressionOpInterface expression = function_definition; +// ``` +// +// We can also type check and explicitly convert a `LairExpressionOpInterface` +// to a `LairFunctionDefinitionOp`: +// +// ``` +// LairExpressionOpInterface expression = ...; +// if (llvm::isa<LairFunctionDefinitionOp>(expression)) { +// auto function_definition = +// llvm::cast<LairFunctionDefinitionOp>(expression); +// ...
+// } +// ``` + +include "mlir/IR/OpBase.td" + +def LairExpressionOpInterface : OpInterface<"LairExpressionOpInterface"> { + let cppNamespace = "::maldoca"; + + let extraClassDeclaration = [{ + operator mlir::Value() { // NOLINT + return getOperation()->getResult(0); + } + }]; +} + +def LairExpressionOpInterfaceTraits : TraitList<[ + DeclareOpInterfaceMethods +]>; diff --git a/maldoca/astgen/test/lambda/ir.cc b/maldoca/astgen/test/lambda/ir.cc new file mode 100644 index 0000000..6742128 --- /dev/null +++ b/maldoca/astgen/test/lambda/ir.cc @@ -0,0 +1,96 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/test/lambda/ir.h" + +// IWYU pragma: begin_keep + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Region.h" + +// IWYU pragma: end_keep + +// ============================================================================= +// Dialect Definition +// ============================================================================= + +#include "maldoca/astgen/test/lambda/lair_dialect.cc.inc" + +/// Dialect initialization, the instance will be owned by the context. 
This is +/// the point of registration of types and operations for the dialect. +void maldoca::LairDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "maldoca/astgen/test/lambda/lair_types.cc.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "maldoca/astgen/test/lambda/lair_ops.cc.inc" + >(); +} + +// ============================================================================= +// Dialect Interface Definitions +// ============================================================================= + +#include "maldoca/astgen/test/lambda/interfaces.cc.inc" + +// ============================================================================= +// Dialect Type Definitions +// ============================================================================= + +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/lambda/lair_types.cc.inc" + +// ============================================================================= +// Dialect Op Definitions +// ============================================================================= + +#define GET_OP_CLASSES +#include "maldoca/astgen/test/lambda/lair_ops.cc.inc" + +// ============================================================================= +// Utils +// ============================================================================= + +namespace maldoca { + +bool IsExprRegion(mlir::Region &region) { + // Region must have exactly one block. + if (!llvm::hasSingleElement(region)) { + return false; + } + + mlir::Block &block = region.front(); + + // Block must have at least one op (terminator).
+ if (block.empty()) { + return false; + } + + auto *terminator = &block.back(); + return llvm::isa<LairExprRegionEndOp>(terminator); +} + +} // namespace maldoca diff --git a/maldoca/astgen/test/lambda/ir.h b/maldoca/astgen/test/lambda/ir.h new file mode 100644 index 0000000..ad7a7a3 --- /dev/null +++ b/maldoca/astgen/test/lambda/ir.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_LAMBDA_IR_H_ +#define MALDOCA_ASTGEN_TEST_LAMBDA_IR_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" + +namespace maldoca { + +// Checks that the region contains a single block that terminates with +// LairExprRegionEndOp. This means that this region calculates a single +// expression. +bool IsExprRegion(mlir::Region &region); + +} // namespace maldoca + +// Include the auto-generated header file containing the declaration of the LAIR +// dialect. +#include "maldoca/astgen/test/lambda/lair_dialect.h.inc" + +// Include the auto-generated header file containing the declarations of the +// LAIR interfaces. +#include "maldoca/astgen/test/lambda/interfaces.h.inc" + +// Include the auto-generated header file containing the declarations of the +// LAIR types. +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/lambda/lair_types.h.inc" + +// Include the auto-generated header file containing the declarations of the +// LAIR operations.
+#define GET_OP_CLASSES +#include "maldoca/astgen/test/lambda/lair_ops.h.inc" + +#endif // MALDOCA_ASTGEN_TEST_LAMBDA_IR_H_ diff --git a/maldoca/astgen/test/lambda/lair_dialect.td b/maldoca/astgen/test/lambda/lair_dialect.td new file mode 100644 index 0000000..54e7f3b --- /dev/null +++ b/maldoca/astgen/test/lambda/lair_dialect.td @@ -0,0 +1,46 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_LAMBDDA_LIR_DIALECT_TD_ +#define MALDOCA_ASTGEN_TEST_LAMBDDA_LIR_DIALECT_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def Lair_Dialect : Dialect { + let name = "lair"; + let cppNamespace = "::maldoca"; + + let description = [{ + The LambdaIR, a test IR that models lambda calculus. All ops and fields are + directly mapped from the AST. + + In LAIR, we showcase ASTGen's ability of putting expression-likes in + regions. 
+ }]; + let useDefaultTypePrinterParser = 1; +} + +class Lair_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = ?; +} + +class Lair_Op traits = []> : + Op; + +def ExprRegion : Region>; + +#endif // MALDOCA_ASTGEN_TEST_LAMBDDA_LIR_DIALECT_TD_ diff --git a/maldoca/astgen/test/lambda/lair_ops.generated.td b/maldoca/astgen/test/lambda/lair_ops.generated.td new file mode 100644 index 0000000..a8bddbb --- /dev/null +++ b/maldoca/astgen/test/lambda/lair_ops.generated.td @@ -0,0 +1,160 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_LAMBDA_LAIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_LAMBDA_LAIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/lambda/interfaces.td" +include "maldoca/astgen/test/lambda/lair_dialect.td" +include "maldoca/astgen/test/lambda/lair_types.td" + +// lair.*_region_end: An artificial op at the end of a region to collect +// expression-related values. +// +// Take lair.exprs_region_end as example: +// ====================================== +// +// Consider the following function declaration: +// ``` +// function foo(a, b = defaultValue) { +// ... +// } +// ``` +// +// We lower it to the following IR (simplified): +// ``` +// %0 = lair.identifier_ref {"foo"} +// lair.function_declaration(%0) ( +// // params +// { +// %1 = lair.identifier_ref {"a"} +// %2 = lair.identifier_ref {"b"} +// %3 = lair.identifier {"defaultValue"} +// %4 = lair.assignment_pattern_ref(%2, %3) +// lair.exprs_region_end(%1, %4) +// }, +// // body +// { +// ... +// } +// ) +// ``` +// +// We can see that: +// +// 1. We put the parameter-related ops in a region, instead of taking them as +// normal arguments. In other words, we don't do this: +// +// ``` +// %0 = lair.identifier_ref {"foo"} +// %1 = lair.identifier_ref {"a"} +// %2 = lair.identifier_ref {"b"} +// %3 = lair.identifier {"defaultValue"} +// %4 = lair.assignment_pattern_ref(%2, %3) +// lair.function_declaration(%0, [%1, %4]) ( +// // body +// { +// ... +// } +// ) +// ``` +// +// The reason is that sometimes an argument might have a default value, and +// the evaluation of that default value happens once for each function call +// (i.e.
it happens "within" the function). If we take the parameter as +// normal argument, then %3 is only evaluated once - at function definition +// time. +// +// 2. Even though the function has two parameters, we use 4 ops to represent +// them. This is because some parameters are more complex and require more +// than one op. +// +// 3. We use "lair.exprs_region_end" to list the "top-level" ops for the +// parameters. In the example above, ops [%2, %3, %4] all represent the +// parameter "b = defaultValue", but %4 is the top-level one. In other words, +// %4 is the root of the tree [%2, %3, %4]. +// +// 4. Strictly speaking, we don't really need "lair.exprs_region_end". The ops +// within the "params" region form several trees, and we can figure out what +// the roots are (a root is an op whose return value is not used by any other +// op). So the use of "lair.exprs_region_end" is mostly for convenience. +def LairExprRegionEndOp : Lair_Op<"expr_region_end", [Terminator]> { + let arguments = (ins + AnyType: $argument + ); +} + +def LairVariableOp : Lair_Op< + "variable", [ + LairExpressionOpInterfaceTraits + ]> { + let arguments = (ins + StrAttr: $identifier + ); + + let results = (outs + LairAnyType + ); +} + +def LairVariableRefOp : Lair_Op<"variable_ref", []> { + let arguments = (ins + StrAttr: $identifier + ); + + let results = (outs + LairAnyType + ); +} + +def LairFunctionDefinitionOp : Lair_Op< + "function_definition", [ + LairExpressionOpInterfaceTraits, + NoTerminator + ]> { + let regions = (region + ExprRegion: $parameter, + ExprRegion: $body + ); + + let results = (outs + LairAnyType + ); +} + +def LairFunctionCallOp : Lair_Op< + "function_call", [ + LairExpressionOpInterfaceTraits + ]> { + let arguments = (ins + AnyType: $function, + AnyType: $argument + ); + + let results = (outs + LairAnyType + ); +} + +#endif // MALDOCA_ASTGEN_TEST_LAMBDA_LAIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/lambda/lair_ops.td b/maldoca/astgen/test/lambda/lair_ops.td 
new file mode 100644 index 0000000..42f78ed --- /dev/null +++ b/maldoca/astgen/test/lambda/lair_ops.td @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_LAMBDA_LIR_OPS_TD_ +#define MALDOCA_ASTGEN_TEST_LAMBDA_LIR_OPS_TD_ + +// Import the generated ops. +include "maldoca/astgen/test/lambda/lair_ops.generated.td" + +#endif // MALDOCA_ASTGEN_TEST_LAMBDA_LIR_OPS_TD_ diff --git a/maldoca/astgen/test/lambda/lair_types.td b/maldoca/astgen/test/lambda/lair_types.td new file mode 100644 index 0000000..c8859c1 --- /dev/null +++ b/maldoca/astgen/test/lambda/lair_types.td @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_LAMBDA_LAIR_TYPES_TD_ +#define MALDOCA_ASTGEN_TEST_LAMBDA_LAIR_TYPES_TD_ + +include "maldoca/astgen/test/lambda/lair_dialect.td" + +def LairAnyType : Lair_Type<"LairAny"> { + let summary = "A placeholder singleton type."; + let mnemonic = "any"; + let assemblyFormat = ""; +} + +#endif // MALDOCA_ASTGEN_TEST_LAMBDA_LAIR_TYPES_TD_ diff --git a/maldoca/astgen/test/list/BUILD b/maldoca/astgen/test/list/BUILD new file mode 100644 index 0000000..d21e963 --- /dev/null +++ b/maldoca/astgen/test/list/BUILD @@ -0,0 +1,185 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//maldoca/astgen:__subpackages__", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + "liir_ops.generated.td", + "//maldoca/astgen/test/list/conversion:ast_to_liir.generated.cc", + "//maldoca/astgen/test/list/conversion:liir_to_ast.generated.cc", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = ["ast.generated.h"], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) + +td_library( + name = "interfaces_td_files", + srcs = [ + "interfaces.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "interfaces_inc_gen", + tbl_outs = { + "interfaces.h.inc": ["-gen-op-interface-decls"], + "interfaces.cc.inc": ["-gen-op-interface-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "interfaces.td", + deps = [":interfaces_td_files"], +) + +td_library( + name = "liir_dialect_td_files", + srcs = [ + "liir_dialect.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "liir_dialect_inc_gen", + tbl_outs = { + "liir_dialect.h.inc": [ + 
"-gen-dialect-decls", + "-dialect=liir", + ], + "liir_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=liir", + ], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "liir_dialect.td", + deps = [":liir_dialect_td_files"], +) + +td_library( + name = "liir_types_td_files", + srcs = [ + "liir_types.td", + ], + deps = [ + ":liir_dialect_td_files", + ], +) + +gentbl_cc_library( + name = "liir_types_inc_gen", + tbl_outs = { + "liir_types.h.inc": ["-gen-typedef-decls"], + "liir_types.cc.inc": ["-gen-typedef-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "liir_types.td", + deps = [":liir_types_td_files"], +) + +td_library( + name = "liir_ops_td_files", + srcs = [ + "liir_ops.generated.td", + "liir_ops.td", + ], + deps = [ + ":interfaces_td_files", + ":liir_dialect_td_files", + ":liir_types_td_files", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "liir_ops_inc_gen", + tbl_outs = { + "liir_ops.h.inc": ["-gen-op-decls"], + "liir_ops.cc.inc": ["-gen-op-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "liir_ops.td", + deps = [":liir_ops_td_files"], +) + +cc_library( + name = "ir", + srcs = ["ir.cc"], + hdrs = ["ir.h"], + deps = [ + ":interfaces_inc_gen", + ":liir_dialect_inc_gen", + ":liir_ops_inc_gen", + ":liir_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + ], +) diff --git a/maldoca/astgen/test/list/ast.generated.cc b/maldoca/astgen/test/list/ast.generated.cc new file mode 100644 index 0000000..c674a3a --- 
/dev/null +++ b/maldoca/astgen/test/list/ast.generated.cc @@ -0,0 +1,305 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#include "maldoca/astgen/test/list/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// LiClass1 +// ============================================================================= + +// ============================================================================= +// LiClass2 +// ============================================================================= + +// ============================================================================= +// LiSimpleList +// 
============================================================================= + +LiSimpleList::LiSimpleList( + std::vector strings, + std::vector> operations) + : strings_(std::move(strings)), + operations_(std::move(operations)) {} + +std::vector* LiSimpleList::strings() { + return &strings_; +} + +const std::vector* LiSimpleList::strings() const { + return &strings_; +} + +void LiSimpleList::set_strings(std::vector strings) { + strings_ = std::move(strings); +} + +std::vector>* LiSimpleList::operations() { + return &operations_; +} + +const std::vector>* LiSimpleList::operations() const { + return &operations_; +} + +void LiSimpleList::set_operations(std::vector> operations) { + operations_ = std::move(operations); +} + +// ============================================================================= +// LiOptionalList +// ============================================================================= + +LiOptionalList::LiOptionalList( + std::optional> strings) + : strings_(std::move(strings)) {} + +std::optional*> LiOptionalList::strings() { + if (!strings_.has_value()) { + return std::nullopt; + } else { + return &strings_.value(); + } +} + +std::optional*> LiOptionalList::strings() const { + if (!strings_.has_value()) { + return std::nullopt; + } else { + return &strings_.value(); + } +} + +void LiOptionalList::set_strings(std::optional> strings) { + strings_ = std::move(strings); +} + +// ============================================================================= +// LiListOfOptional +// ============================================================================= + +LiListOfOptional::LiListOfOptional( + std::vector> strings, + std::vector>> operations) + : strings_(std::move(strings)), + operations_(std::move(operations)) {} + +std::vector>* LiListOfOptional::strings() { + return &strings_; +} + +const std::vector>* LiListOfOptional::strings() const { + return &strings_; +} + +void LiListOfOptional::set_strings(std::vector> strings) { + strings_ = 
std::move(strings); +} + +std::vector>>* LiListOfOptional::operations() { + return &operations_; +} + +const std::vector>>* LiListOfOptional::operations() const { + return &operations_; +} + +void LiListOfOptional::set_operations(std::vector>> operations) { + operations_ = std::move(operations); +} + +// ============================================================================= +// LiListOfVariant +// ============================================================================= + +LiListOfVariant::LiListOfVariant( + std::vector> variants, + std::vector, std::unique_ptr>> operations) + : variants_(std::move(variants)), + operations_(std::move(operations)) {} + +std::vector>* LiListOfVariant::variants() { + return &variants_; +} + +const std::vector>* LiListOfVariant::variants() const { + return &variants_; +} + +void LiListOfVariant::set_variants(std::vector> variants) { + variants_ = std::move(variants); +} + +std::vector, std::unique_ptr>>* LiListOfVariant::operations() { + return &operations_; +} + +const std::vector, std::unique_ptr>>* LiListOfVariant::operations() const { + return &operations_; +} + +void LiListOfVariant::set_operations(std::vector, std::unique_ptr>> operations) { + operations_ = std::move(operations); +} + +// ============================================================================= +// LiOptionalListOfOptional +// ============================================================================= + +LiOptionalListOfOptional::LiOptionalListOfOptional( + std::optional>> variants) + : variants_(std::move(variants)) {} + +std::optional>*> LiOptionalListOfOptional::variants() { + if (!variants_.has_value()) { + return std::nullopt; + } else { + return &variants_.value(); + } +} + +std::optional>*> LiOptionalListOfOptional::variants() const { + if (!variants_.has_value()) { + return std::nullopt; + } else { + return &variants_.value(); + } +} + +void LiOptionalListOfOptional::set_variants(std::optional>> variants) { + variants_ = 
std::move(variants); +} + +// ============================================================================= +// LiOptionalListOfVariant +// ============================================================================= + +LiOptionalListOfVariant::LiOptionalListOfVariant( + std::optional>> variants) + : variants_(std::move(variants)) {} + +std::optional>*> LiOptionalListOfVariant::variants() { + if (!variants_.has_value()) { + return std::nullopt; + } else { + return &variants_.value(); + } +} + +std::optional>*> LiOptionalListOfVariant::variants() const { + if (!variants_.has_value()) { + return std::nullopt; + } else { + return &variants_.value(); + } +} + +void LiOptionalListOfVariant::set_variants(std::optional>> variants) { + variants_ = std::move(variants); +} + +// ============================================================================= +// LiListOfOptionalVariant +// ============================================================================= + +LiListOfOptionalVariant::LiListOfOptionalVariant( + std::vector>> variants, + std::vector, std::unique_ptr>>> operations) + : variants_(std::move(variants)), + operations_(std::move(operations)) {} + +std::vector>>* LiListOfOptionalVariant::variants() { + return &variants_; +} + +const std::vector>>* LiListOfOptionalVariant::variants() const { + return &variants_; +} + +void LiListOfOptionalVariant::set_variants(std::vector>> variants) { + variants_ = std::move(variants); +} + +std::vector, std::unique_ptr>>>* LiListOfOptionalVariant::operations() { + return &operations_; +} + +const std::vector, std::unique_ptr>>>* LiListOfOptionalVariant::operations() const { + return &operations_; +} + +void LiListOfOptionalVariant::set_operations(std::vector, std::unique_ptr>>> operations) { + operations_ = std::move(operations); +} + +// ============================================================================= +// LiOptionalListOfOptionalVariant +// 
============================================================================= + +LiOptionalListOfOptionalVariant::LiOptionalListOfOptionalVariant( + std::optional>>> variants) + : variants_(std::move(variants)) {} + +std::optional>>*> LiOptionalListOfOptionalVariant::variants() { + if (!variants_.has_value()) { + return std::nullopt; + } else { + return &variants_.value(); + } +} + +std::optional>>*> LiOptionalListOfOptionalVariant::variants() const { + if (!variants_.has_value()) { + return std::nullopt; + } else { + return &variants_.value(); + } +} + +void LiOptionalListOfOptionalVariant::set_variants(std::optional>>> variants) { + variants_ = std::move(variants); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/list/ast.generated.h b/maldoca/astgen/test/list/ast.generated.h new file mode 100644 index 0000000..af865d6 --- /dev/null +++ b/maldoca/astgen/test/list/ast.generated.h @@ -0,0 +1,313 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_LIST_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_LIST_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +class LiClass1 { + public: + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class LiClass2 { + public: + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class LiSimpleList { + public: + explicit LiSimpleList( + std::vector strings, + std::vector> operations); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::vector* strings(); + const std::vector* strings() const; + void set_strings(std::vector strings); + + std::vector>* operations(); + const std::vector>* operations() const; + void set_operations(std::vector> operations); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr> GetStrings(const nlohmann::json& json); + static absl::StatusOr>> GetOperations(const nlohmann::json& json); + + private: + std::vector strings_; + std::vector> operations_; +}; + +class LiOptionalList { + public: + explicit LiOptionalList( + std::optional> strings); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::optional*> strings(); + std::optional*> strings() const; + void set_strings(std::optional> strings); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr>> GetStrings(const nlohmann::json& json); + + private: + std::optional> strings_; +}; + +class LiListOfOptional { + public: + explicit LiListOfOptional( + std::vector> strings, + std::vector>> operations); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::vector>* strings(); + const std::vector>* strings() const; + void set_strings(std::vector> strings); + + std::vector>>* operations(); + const std::vector>>* operations() const; + void set_operations(std::vector>> operations); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr>> GetStrings(const nlohmann::json& json); + static absl::StatusOr>>> GetOperations(const nlohmann::json& json); + + private: + std::vector> strings_; + std::vector>> operations_; +}; + +class LiListOfVariant { + public: + explicit LiListOfVariant( + std::vector> variants, + std::vector, std::unique_ptr>> operations); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::vector>* variants(); + const std::vector>* variants() const; + void set_variants(std::vector> variants); + + std::vector, std::unique_ptr>>* operations(); + const std::vector, std::unique_ptr>>* operations() const; + void set_operations(std::vector, std::unique_ptr>> operations); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr>> GetVariants(const nlohmann::json& json); + static absl::StatusOr, std::unique_ptr>>> GetOperations(const nlohmann::json& json); + + private: + std::vector> variants_; + std::vector, std::unique_ptr>> operations_; +}; + +class LiOptionalListOfOptional { + public: + explicit LiOptionalListOfOptional( + std::optional>> variants); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::optional>*> variants(); + std::optional>*> variants() const; + void set_variants(std::optional>> variants); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr>>> GetVariants(const nlohmann::json& json); + + private: + std::optional>> variants_; +}; + +class LiOptionalListOfVariant { + public: + explicit LiOptionalListOfVariant( + std::optional>> variants); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::optional>*> variants(); + std::optional>*> variants() const; + void set_variants(std::optional>> variants); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr>>> GetVariants(const nlohmann::json& json); + + private: + std::optional>> variants_; +}; + +class LiListOfOptionalVariant { + public: + explicit LiListOfOptionalVariant( + std::vector>> variants, + std::vector, std::unique_ptr>>> operations); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::vector>>* variants(); + const std::vector>>* variants() const; + void set_variants(std::vector>> variants); + + std::vector, std::unique_ptr>>>* operations(); + const std::vector, std::unique_ptr>>>* operations() const; + void set_operations(std::vector, std::unique_ptr>>> operations); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr>>> GetVariants(const nlohmann::json& json); + static absl::StatusOr, std::unique_ptr>>>> GetOperations(const nlohmann::json& json); + + private: + std::vector>> variants_; + std::vector, std::unique_ptr>>> operations_; +}; + +class LiOptionalListOfOptionalVariant { + public: + explicit LiOptionalListOfOptionalVariant( + std::optional>>> variants); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::optional>>*> variants(); + std::optional>>*> variants() const; + void set_variants(std::optional>>> variants); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr>>>> GetVariants(const nlohmann::json& json); + + private: + std::optional>>> variants_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_LIST_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/list/ast_def.textproto b/maldoca/astgen/test/list/ast_def.textproto new file mode 100644 index 0000000..1882ac3 --- /dev/null +++ b/maldoca/astgen/test/list/ast_def.textproto @@ -0,0 +1,242 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "li" + +# interface Class1 {} +nodes { + name: "Class1" + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface Class2 {} +nodes { + name: "Class2" + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface SimpleList { +# strings: [string] +# operations: [Class1] +# } +nodes { + name: "SimpleList" + fields { + name: "strings" + type { + list { + element_type { string {} } + } + } + kind: FIELD_KIND_ATTR + } + fields { + name: "operations" + type { + list { + 
element_type { class: "Class1" } + } + } + kind: FIELD_KIND_RVAL + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface OptionalList { +# strings? : [string] +# } +nodes { + name: "OptionalList" + fields { + name: "strings" + optionalness: OPTIONALNESS_MAYBE_UNDEFINED + type { + list { + element_type { string {} } + } + } + kind: FIELD_KIND_ATTR + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface ListOfOptional { +# strings: [ string | null ] +# operations: [ Class1 | null ] +# } +nodes { + name: "ListOfOptional" + fields { + name: "strings" + type { + list { + element_type { string {} } + element_maybe_null: true + } + } + kind: FIELD_KIND_ATTR + } + fields { + name: "operations" + type { + list { + element_type { class: "Class1" } + element_maybe_null: true + } + } + kind: FIELD_KIND_RVAL + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface ListOfVariant { +# variants: [ boolean | string ] +# operations: [ Class1 | Class2 ] +# } +nodes { + name: "ListOfVariant" + fields { + name: "variants" + type { + list { + element_type { + variant { + types { bool {} } + types { string {} } + } + } + } + } + kind: FIELD_KIND_ATTR + } + fields { + name: "operations" + type { + list { + element_type { + variant { + types { class: "Class1" } + types { class: "Class2" } + } + } + } + } + kind: FIELD_KIND_RVAL + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface OptionalListOfOptional { +# variants? : [ string | null ] +# } +nodes { + name: "OptionalListOfOptional" + fields { + name: "variants" + optionalness: OPTIONALNESS_MAYBE_UNDEFINED + type { + list { + element_type { string {} } + element_maybe_null: true + } + } + kind: FIELD_KIND_ATTR + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface OptionalListOfVariant { +# variants? 
: [ boolean | string ] +# } +nodes { + name: "OptionalListOfVariant" + fields { + name: "variants" + optionalness: OPTIONALNESS_MAYBE_UNDEFINED + type { + list { + element_type { + variant { + types { bool {} } + types { string {} } + } + } + } + } + kind: FIELD_KIND_ATTR + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface ListOfOptionalVariant { +# variants: [ boolean | string | null ] +# operations: [ Class1 | Class2 | null ] +# } +nodes { + name: "ListOfOptionalVariant" + fields { + name: "variants" + type { + list { + element_type { + variant { + types { bool {} } + types { string {} } + } + } + element_maybe_null: true + } + } + kind: FIELD_KIND_ATTR + } + fields { + name: "operations" + type { + list { + element_type { + variant { + types { class: "Class1" } + types { class: "Class2" } + } + } + element_maybe_null: true + } + } + kind: FIELD_KIND_RVAL + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface OptionalListOfOptionalVariant { +# variants? : [ boolean | string | null ] +# } +nodes { + name: "OptionalListOfOptionalVariant" + fields { + name: "variants" + optionalness: OPTIONALNESS_MAYBE_UNDEFINED + type { + list { + element_type { + variant { + types { bool {} } + types { string {} } + } + } + element_maybe_null: true + } + } + kind: FIELD_KIND_ATTR + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} diff --git a/maldoca/astgen/test/list/ast_from_json.generated.cc b/maldoca/astgen/test/list/ast_from_json.generated.cc new file mode 100644 index 0000000..2a48abb --- /dev/null +++ b/maldoca/astgen/test/list/ast_from_json.generated.cc @@ -0,0 +1,608 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/list/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +// ============================================================================= +// LiClass1 +// ============================================================================= + +static bool IsClass1(const nlohmann::json& json) { + if (!json.is_object()) { + return false; + } + return true; +} + +absl::StatusOr> +LiClass1::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + + return absl::make_unique( + ); +} + +// ============================================================================= +// LiClass2 +// ============================================================================= + +static bool IsClass2(const nlohmann::json& json) { + if (!json.is_object()) { + return false; + } + return true; +} + 
+absl::StatusOr> +LiClass2::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + + return absl::make_unique( + ); +} + +// ============================================================================= +// LiSimpleList +// ============================================================================= + +absl::StatusOr> +LiSimpleList::GetStrings(const nlohmann::json& json) { + auto strings_it = json.find("strings"); + if (strings_it == json.end()) { + return absl::InvalidArgumentError("`strings` is undefined."); + } + const nlohmann::json& json_strings = strings_it.value(); + + if (json_strings.is_null()) { + return absl::InvalidArgumentError("json_strings is null."); + } + if (!json_strings.is_array()) { + return absl::InvalidArgumentError("json_strings expected to be array."); + } + + std::vector strings; + for (const nlohmann::json& json_strings_element : json_strings) { + if (json_strings_element.is_null()) { + return absl::InvalidArgumentError("json_strings_element is null."); + } + if (!json_strings_element.is_string()) { + return absl::InvalidArgumentError("Expecting json_strings_element.is_string()."); + } + auto strings_element = json_strings_element.get(); + strings.push_back(std::move(strings_element)); + } + return strings; +} + +absl::StatusOr>> +LiSimpleList::GetOperations(const nlohmann::json& json) { + auto operations_it = json.find("operations"); + if (operations_it == json.end()) { + return absl::InvalidArgumentError("`operations` is undefined."); + } + const nlohmann::json& json_operations = operations_it.value(); + + if (json_operations.is_null()) { + return absl::InvalidArgumentError("json_operations is null."); + } + if (!json_operations.is_array()) { + return absl::InvalidArgumentError("json_operations expected to be array."); + } + + std::vector> operations; + for (const nlohmann::json& json_operations_element : json_operations) { + if 
(json_operations_element.is_null()) { + return absl::InvalidArgumentError("json_operations_element is null."); + } + MALDOCA_ASSIGN_OR_RETURN(auto operations_element, LiClass1::FromJson(json_operations_element)); + operations.push_back(std::move(operations_element)); + } + return operations; +} + +absl::StatusOr> +LiSimpleList::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto strings, LiSimpleList::GetStrings(json)); + MALDOCA_ASSIGN_OR_RETURN(auto operations, LiSimpleList::GetOperations(json)); + + return absl::make_unique( + std::move(strings), + std::move(operations)); +} + +// ============================================================================= +// LiOptionalList +// ============================================================================= + +absl::StatusOr>> +LiOptionalList::GetStrings(const nlohmann::json& json) { + auto strings_it = json.find("strings"); + if (strings_it == json.end()) { + return std::nullopt; + } + const nlohmann::json& json_strings = strings_it.value(); + + if (json_strings.is_null()) { + return absl::InvalidArgumentError("json_strings is null."); + } + if (!json_strings.is_array()) { + return absl::InvalidArgumentError("json_strings expected to be array."); + } + + std::vector strings; + for (const nlohmann::json& json_strings_element : json_strings) { + if (json_strings_element.is_null()) { + return absl::InvalidArgumentError("json_strings_element is null."); + } + if (!json_strings_element.is_string()) { + return absl::InvalidArgumentError("Expecting json_strings_element.is_string()."); + } + auto strings_element = json_strings_element.get(); + strings.push_back(std::move(strings_element)); + } + return strings; +} + +absl::StatusOr> +LiOptionalList::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + 
MALDOCA_ASSIGN_OR_RETURN(auto strings, LiOptionalList::GetStrings(json)); + + return absl::make_unique( + std::move(strings)); +} + +// ============================================================================= +// LiListOfOptional +// ============================================================================= + +absl::StatusOr>> +LiListOfOptional::GetStrings(const nlohmann::json& json) { + auto strings_it = json.find("strings"); + if (strings_it == json.end()) { + return absl::InvalidArgumentError("`strings` is undefined."); + } + const nlohmann::json& json_strings = strings_it.value(); + + if (json_strings.is_null()) { + return absl::InvalidArgumentError("json_strings is null."); + } + if (!json_strings.is_array()) { + return absl::InvalidArgumentError("json_strings expected to be array."); + } + + std::vector> strings; + for (const nlohmann::json& json_strings_element : json_strings) { + std::optional strings_element; + if (!json_strings_element.is_null()) { + if (!json_strings_element.is_string()) { + return absl::InvalidArgumentError("Expecting json_strings_element.is_string()."); + } + strings_element = json_strings_element.get(); + } + strings.push_back(std::move(strings_element)); + } + return strings; +} + +absl::StatusOr>>> +LiListOfOptional::GetOperations(const nlohmann::json& json) { + auto operations_it = json.find("operations"); + if (operations_it == json.end()) { + return absl::InvalidArgumentError("`operations` is undefined."); + } + const nlohmann::json& json_operations = operations_it.value(); + + if (json_operations.is_null()) { + return absl::InvalidArgumentError("json_operations is null."); + } + if (!json_operations.is_array()) { + return absl::InvalidArgumentError("json_operations expected to be array."); + } + + std::vector>> operations; + for (const nlohmann::json& json_operations_element : json_operations) { + std::optional> operations_element; + if (!json_operations_element.is_null()) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, 
LiClass1::FromJson(json_operations_element)); + } + operations.push_back(std::move(operations_element)); + } + return operations; +} + +absl::StatusOr> +LiListOfOptional::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto strings, LiListOfOptional::GetStrings(json)); + MALDOCA_ASSIGN_OR_RETURN(auto operations, LiListOfOptional::GetOperations(json)); + + return absl::make_unique( + std::move(strings), + std::move(operations)); +} + +// ============================================================================= +// LiListOfVariant +// ============================================================================= + +absl::StatusOr>> +LiListOfVariant::GetVariants(const nlohmann::json& json) { + auto variants_it = json.find("variants"); + if (variants_it == json.end()) { + return absl::InvalidArgumentError("`variants` is undefined."); + } + const nlohmann::json& json_variants = variants_it.value(); + + if (json_variants.is_null()) { + return absl::InvalidArgumentError("json_variants is null."); + } + if (!json_variants.is_array()) { + return absl::InvalidArgumentError("json_variants expected to be array."); + } + + std::vector> variants; + for (const nlohmann::json& json_variants_element : json_variants) { + if (json_variants_element.is_null()) { + return absl::InvalidArgumentError("json_variants_element is null."); + } + std::variant variants_element; + if (json_variants_element.is_boolean()) { + variants_element = json_variants_element.get(); + } else if (json_variants_element.is_string()) { + variants_element = json_variants_element.get(); + } else { + auto result = absl::InvalidArgumentError("json_variants_element has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_variants_element.dump()}); + return result; + } + variants.push_back(std::move(variants_element)); + } + return 
variants; +} + +absl::StatusOr, std::unique_ptr>>> +LiListOfVariant::GetOperations(const nlohmann::json& json) { + auto operations_it = json.find("operations"); + if (operations_it == json.end()) { + return absl::InvalidArgumentError("`operations` is undefined."); + } + const nlohmann::json& json_operations = operations_it.value(); + + if (json_operations.is_null()) { + return absl::InvalidArgumentError("json_operations is null."); + } + if (!json_operations.is_array()) { + return absl::InvalidArgumentError("json_operations expected to be array."); + } + + std::vector, std::unique_ptr>> operations; + for (const nlohmann::json& json_operations_element : json_operations) { + if (json_operations_element.is_null()) { + return absl::InvalidArgumentError("json_operations_element is null."); + } + std::variant, std::unique_ptr> operations_element; + if (IsClass1(json_operations_element)) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, LiClass1::FromJson(json_operations_element)); + } else if (IsClass2(json_operations_element)) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, LiClass2::FromJson(json_operations_element)); + } else { + auto result = absl::InvalidArgumentError("json_operations_element has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_operations_element.dump()}); + return result; + } + operations.push_back(std::move(operations_element)); + } + return operations; +} + +absl::StatusOr> +LiListOfVariant::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto variants, LiListOfVariant::GetVariants(json)); + MALDOCA_ASSIGN_OR_RETURN(auto operations, LiListOfVariant::GetOperations(json)); + + return absl::make_unique( + std::move(variants), + std::move(operations)); +} + +// ============================================================================= +// LiOptionalListOfOptional 
+// ============================================================================= + +absl::StatusOr>>> +LiOptionalListOfOptional::GetVariants(const nlohmann::json& json) { + auto variants_it = json.find("variants"); + if (variants_it == json.end()) { + return std::nullopt; + } + const nlohmann::json& json_variants = variants_it.value(); + + if (json_variants.is_null()) { + return absl::InvalidArgumentError("json_variants is null."); + } + if (!json_variants.is_array()) { + return absl::InvalidArgumentError("json_variants expected to be array."); + } + + std::vector> variants; + for (const nlohmann::json& json_variants_element : json_variants) { + std::optional variants_element; + if (!json_variants_element.is_null()) { + if (!json_variants_element.is_string()) { + return absl::InvalidArgumentError("Expecting json_variants_element.is_string()."); + } + variants_element = json_variants_element.get(); + } + variants.push_back(std::move(variants_element)); + } + return variants; +} + +absl::StatusOr> +LiOptionalListOfOptional::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto variants, LiOptionalListOfOptional::GetVariants(json)); + + return absl::make_unique( + std::move(variants)); +} + +// ============================================================================= +// LiOptionalListOfVariant +// ============================================================================= + +absl::StatusOr>>> +LiOptionalListOfVariant::GetVariants(const nlohmann::json& json) { + auto variants_it = json.find("variants"); + if (variants_it == json.end()) { + return std::nullopt; + } + const nlohmann::json& json_variants = variants_it.value(); + + if (json_variants.is_null()) { + return absl::InvalidArgumentError("json_variants is null."); + } + if (!json_variants.is_array()) { + return absl::InvalidArgumentError("json_variants expected to be array."); + } + + 
std::vector> variants; + for (const nlohmann::json& json_variants_element : json_variants) { + if (json_variants_element.is_null()) { + return absl::InvalidArgumentError("json_variants_element is null."); + } + std::variant variants_element; + if (json_variants_element.is_boolean()) { + variants_element = json_variants_element.get(); + } else if (json_variants_element.is_string()) { + variants_element = json_variants_element.get(); + } else { + auto result = absl::InvalidArgumentError("json_variants_element has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_variants_element.dump()}); + return result; + } + variants.push_back(std::move(variants_element)); + } + return variants; +} + +absl::StatusOr> +LiOptionalListOfVariant::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto variants, LiOptionalListOfVariant::GetVariants(json)); + + return absl::make_unique( + std::move(variants)); +} + +// ============================================================================= +// LiListOfOptionalVariant +// ============================================================================= + +absl::StatusOr>>> +LiListOfOptionalVariant::GetVariants(const nlohmann::json& json) { + auto variants_it = json.find("variants"); + if (variants_it == json.end()) { + return absl::InvalidArgumentError("`variants` is undefined."); + } + const nlohmann::json& json_variants = variants_it.value(); + + if (json_variants.is_null()) { + return absl::InvalidArgumentError("json_variants is null."); + } + if (!json_variants.is_array()) { + return absl::InvalidArgumentError("json_variants expected to be array."); + } + + std::vector>> variants; + for (const nlohmann::json& json_variants_element : json_variants) { + std::optional> variants_element; + if (!json_variants_element.is_null()) { + if 
(json_variants_element.is_boolean()) { + variants_element = json_variants_element.get(); + } else if (json_variants_element.is_string()) { + variants_element = json_variants_element.get(); + } else { + auto result = absl::InvalidArgumentError("json_variants_element has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_variants_element.dump()}); + return result; + } + } + variants.push_back(std::move(variants_element)); + } + return variants; +} + +absl::StatusOr, std::unique_ptr>>>> +LiListOfOptionalVariant::GetOperations(const nlohmann::json& json) { + auto operations_it = json.find("operations"); + if (operations_it == json.end()) { + return absl::InvalidArgumentError("`operations` is undefined."); + } + const nlohmann::json& json_operations = operations_it.value(); + + if (json_operations.is_null()) { + return absl::InvalidArgumentError("json_operations is null."); + } + if (!json_operations.is_array()) { + return absl::InvalidArgumentError("json_operations expected to be array."); + } + + std::vector, std::unique_ptr>>> operations; + for (const nlohmann::json& json_operations_element : json_operations) { + std::optional, std::unique_ptr>> operations_element; + if (!json_operations_element.is_null()) { + if (IsClass1(json_operations_element)) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, LiClass1::FromJson(json_operations_element)); + } else if (IsClass2(json_operations_element)) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, LiClass2::FromJson(json_operations_element)); + } else { + auto result = absl::InvalidArgumentError("json_operations_element has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_operations_element.dump()}); + return result; + } + } + operations.push_back(std::move(operations_element)); + } + return operations; +} + +absl::StatusOr> +LiListOfOptionalVariant::FromJson(const nlohmann::json& 
json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto variants, LiListOfOptionalVariant::GetVariants(json)); + MALDOCA_ASSIGN_OR_RETURN(auto operations, LiListOfOptionalVariant::GetOperations(json)); + + return absl::make_unique( + std::move(variants), + std::move(operations)); +} + +// ============================================================================= +// LiOptionalListOfOptionalVariant +// ============================================================================= + +absl::StatusOr>>>> +LiOptionalListOfOptionalVariant::GetVariants(const nlohmann::json& json) { + auto variants_it = json.find("variants"); + if (variants_it == json.end()) { + return std::nullopt; + } + const nlohmann::json& json_variants = variants_it.value(); + + if (json_variants.is_null()) { + return absl::InvalidArgumentError("json_variants is null."); + } + if (!json_variants.is_array()) { + return absl::InvalidArgumentError("json_variants expected to be array."); + } + + std::vector>> variants; + for (const nlohmann::json& json_variants_element : json_variants) { + std::optional> variants_element; + if (!json_variants_element.is_null()) { + if (json_variants_element.is_boolean()) { + variants_element = json_variants_element.get(); + } else if (json_variants_element.is_string()) { + variants_element = json_variants_element.get(); + } else { + auto result = absl::InvalidArgumentError("json_variants_element has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_variants_element.dump()}); + return result; + } + } + variants.push_back(std::move(variants_element)); + } + return variants; +} + +absl::StatusOr> +LiOptionalListOfOptionalVariant::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto variants, 
LiOptionalListOfOptionalVariant::GetVariants(json)); + + return absl::make_unique( + std::move(variants)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/list/ast_gen_test.cc b/maldoca/astgen/test/list/ast_gen_test.cc new file mode 100644 index 0000000..9eed538 --- /dev/null +++ b/maldoca/astgen/test/list/ast_gen_test.cc @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + List, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = + "maldoca/astgen/test/list/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "list/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/list", + .ir_path = "maldoca/astgen/test/list", + .expected_ast_header_path = + "maldoca/astgen/test/list/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/list/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "list/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "list/ast_from_json.generated.cc", + .expected_ir_tablegen_path = + "maldoca/astgen/test/" + "list/liir_ops.generated.td", + .expected_ast_to_ir_source_path = + "maldoca/astgen/test/" + "list/conversion/ast_to_liir.generated.cc", + .expected_ir_to_ast_source_path = + "maldoca/astgen/test/" + "list/conversion/liir_to_ast.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/list/ast_to_json.generated.cc b/maldoca/astgen/test/list/ast_to_json.generated.cc new file mode 100644 index 0000000..3a6eef9 --- /dev/null +++ b/maldoca/astgen/test/list/ast_to_json.generated.cc @@ -0,0 +1,428 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/list/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// LiClass1 +// ============================================================================= + +void LiClass1::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +void LiClass1::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiClass1::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiClass2 +// ============================================================================= + +void LiClass2::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +void LiClass2::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiClass2::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiSimpleList +// ============================================================================= + +void LiSimpleList::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"strings\":" << "["; + { + bool needs_comma = false; + for 
(const auto& element : strings_) { + MaybeAddComma(os, needs_comma); + os << (nlohmann::json(element)).dump(); + } + } + os << "]"; + MaybeAddComma(os, needs_comma); + os << "\"operations\":" << "["; + { + bool needs_comma = false; + for (const auto& element : operations_) { + MaybeAddComma(os, needs_comma); + element->Serialize(os); + } + } + os << "]"; +} + +void LiSimpleList::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiSimpleList::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiOptionalList +// ============================================================================= + +void LiOptionalList::SerializeFields(std::ostream& os, bool &needs_comma) const { + if (strings_.has_value()) { + MaybeAddComma(os, needs_comma); + os << "\"strings\":" << "["; + { + bool needs_comma = false; + for (const auto& element : strings_.value()) { + MaybeAddComma(os, needs_comma); + os << (nlohmann::json(element)).dump(); + } + } + os << "]"; + } +} + +void LiOptionalList::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiOptionalList::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiListOfOptional +// ============================================================================= + +void LiListOfOptional::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"strings\":" << "["; + { + bool needs_comma = false; + for (const auto& element : strings_) { + MaybeAddComma(os, needs_comma); + if (element.has_value()) { + os << (nlohmann::json(element.value())).dump(); + } else { + os << "null"; + } + } + } + os << "]"; + MaybeAddComma(os, needs_comma); + os << "\"operations\":" << "["; + { + bool needs_comma = false; + for (const auto& element : operations_) { + 
MaybeAddComma(os, needs_comma); + if (element.has_value()) { + element.value()->Serialize(os); + } else { + os << "null"; + } + } + } + os << "]"; +} + +void LiListOfOptional::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiListOfOptional::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiListOfVariant +// ============================================================================= + +void LiListOfVariant::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"variants\":" << "["; + { + bool needs_comma = false; + for (const auto& element : variants_) { + MaybeAddComma(os, needs_comma); + switch (element.index()) { + case 0: { + os << (nlohmann::json(std::get<0>(element))).dump(); + break; + } + case 1: { + os << (nlohmann::json(std::get<1>(element))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + } + os << "]"; + MaybeAddComma(os, needs_comma); + os << "\"operations\":" << "["; + { + bool needs_comma = false; + for (const auto& element : operations_) { + MaybeAddComma(os, needs_comma); + switch (element.index()) { + case 0: { + std::get<0>(element)->Serialize(os); + break; + } + case 1: { + std::get<1>(element)->Serialize(os); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + } + os << "]"; +} + +void LiListOfVariant::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiListOfVariant::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiOptionalListOfOptional +// ============================================================================= + +void LiOptionalListOfOptional::SerializeFields(std::ostream& os, bool &needs_comma) const { + if (variants_.has_value()) { + MaybeAddComma(os, needs_comma); + os << 
"\"variants\":" << "["; + { + bool needs_comma = false; + for (const auto& element : variants_.value()) { + MaybeAddComma(os, needs_comma); + if (element.has_value()) { + os << (nlohmann::json(element.value())).dump(); + } else { + os << "null"; + } + } + } + os << "]"; + } +} + +void LiOptionalListOfOptional::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiOptionalListOfOptional::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiOptionalListOfVariant +// ============================================================================= + +void LiOptionalListOfVariant::SerializeFields(std::ostream& os, bool &needs_comma) const { + if (variants_.has_value()) { + MaybeAddComma(os, needs_comma); + os << "\"variants\":" << "["; + { + bool needs_comma = false; + for (const auto& element : variants_.value()) { + MaybeAddComma(os, needs_comma); + switch (element.index()) { + case 0: { + os << (nlohmann::json(std::get<0>(element))).dump(); + break; + } + case 1: { + os << (nlohmann::json(std::get<1>(element))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + } + os << "]"; + } +} + +void LiOptionalListOfVariant::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiOptionalListOfVariant::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiListOfOptionalVariant +// ============================================================================= + +void LiListOfOptionalVariant::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"variants\":" << "["; + { + bool needs_comma = false; + for (const auto& element : variants_) { + MaybeAddComma(os, needs_comma); + if (element.has_value()) { + switch (element.value().index()) { + case 0: { + os << 
(nlohmann::json(std::get<0>(element.value()))).dump(); + break; + } + case 1: { + os << (nlohmann::json(std::get<1>(element.value()))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } else { + os << "null"; + } + } + } + os << "]"; + MaybeAddComma(os, needs_comma); + os << "\"operations\":" << "["; + { + bool needs_comma = false; + for (const auto& element : operations_) { + MaybeAddComma(os, needs_comma); + if (element.has_value()) { + switch (element.value().index()) { + case 0: { + std::get<0>(element.value())->Serialize(os); + break; + } + case 1: { + std::get<1>(element.value())->Serialize(os); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } else { + os << "null"; + } + } + } + os << "]"; +} + +void LiListOfOptionalVariant::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiListOfOptionalVariant::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// LiOptionalListOfOptionalVariant +// ============================================================================= + +void LiOptionalListOfOptionalVariant::SerializeFields(std::ostream& os, bool &needs_comma) const { + if (variants_.has_value()) { + MaybeAddComma(os, needs_comma); + os << "\"variants\":" << "["; + { + bool needs_comma = false; + for (const auto& element : variants_.value()) { + MaybeAddComma(os, needs_comma); + if (element.has_value()) { + switch (element.value().index()) { + case 0: { + os << (nlohmann::json(std::get<0>(element.value()))).dump(); + break; + } + case 1: { + os << (nlohmann::json(std::get<1>(element.value()))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } else { + os << "null"; + } + } + } + os << "]"; + } +} + +void LiOptionalListOfOptionalVariant::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + LiOptionalListOfOptionalVariant::SerializeFields(os, 
needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/list/ast_ts_interface.generated b/maldoca/astgen/test/list/ast_ts_interface.generated new file mode 100644 index 0000000..89e3ca5 --- /dev/null +++ b/maldoca/astgen/test/list/ast_ts_interface.generated @@ -0,0 +1,41 @@ +interface Class1 { +} + +interface Class2 { +} + +interface SimpleList { + strings: [ string ] + operations: [ Class1 ] +} + +interface OptionalList { + strings?: [ string ] +} + +interface ListOfOptional { + strings: [ string | null ] + operations: [ Class1 | null ] +} + +interface ListOfVariant { + variants: [ boolean | string ] + operations: [ Class1 | Class2 ] +} + +interface OptionalListOfOptional { + variants?: [ string | null ] +} + +interface OptionalListOfVariant { + variants?: [ boolean | string ] +} + +interface ListOfOptionalVariant { + variants: [ boolean | string | null ] + operations: [ Class1 | Class2 | null ] +} + +interface OptionalListOfOptionalVariant { + variants?: [ boolean | string | null ] +} diff --git a/maldoca/astgen/test/list/conversion/BUILD b/maldoca/astgen/test/list/conversion/BUILD new file mode 100644 index 0000000..9a4e1a9 --- /dev/null +++ b/maldoca/astgen/test/list/conversion/BUILD @@ -0,0 +1,77 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_applicable_licenses = ["//:license"]) + +licenses(["notice"]) + +exports_files([ + "ast_to_liir.generated.cc", + "liir_to_ast.generated.cc", +]) + +cc_library( + name = "ast_to_liir", + srcs = ["ast_to_liir.generated.cc"], + hdrs = ["ast_to_liir.h"], + deps = [ + "//maldoca/astgen/test/list:ast", + "//maldoca/astgen/test/list:ir", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "liir_to_ast", + srcs = ["liir_to_ast.generated.cc"], + hdrs = ["liir_to_ast.h"], + deps = [ + "//maldoca/astgen/test/list:ast", + "//maldoca/astgen/test/list:ir", + "//maldoca/base:status", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_test( + name = "conversion_test", + srcs = ["conversion_test.cc"], + deps = [ + ":ast_to_liir", + ":liir_to_ast", + "//maldoca/astgen/test:conversion_test_util", + "//maldoca/astgen/test/list:ast", + "//maldoca/astgen/test/list:ir", + "@googletest//:gtest_main", + ], +) diff --git a/maldoca/astgen/test/list/conversion/ast_to_liir.generated.cc b/maldoca/astgen/test/list/conversion/ast_to_liir.generated.cc new file mode 100644 index 0000000..db647d9 --- /dev/null +++ b/maldoca/astgen/test/list/conversion/ast_to_liir.generated.cc @@ -0,0 +1,263 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/list/conversion/ast_to_liir.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/astgen/test/list/ast.generated.h" +#include "maldoca/astgen/test/list/ir.h" + +namespace maldoca { + +LiirClass1Op AstToLiir::VisitClass1(const LiClass1 *node) { + return CreateExpr(node); +} + +LiirClass2Op AstToLiir::VisitClass2(const LiClass2 *node) { + return CreateExpr(node); +} + +LiirSimpleListOp AstToLiir::VisitSimpleList(const LiSimpleList *node) { + std::vector mlir_strings_data; + for (const auto &element : *node->strings()) { + mlir::StringAttr mlir_element = builder_.getStringAttr(element); + mlir_strings_data.push_back(std::move(mlir_element)); + } + auto mlir_strings = builder_.getArrayAttr(mlir_strings_data); + std::vector mlir_operations; + for (const auto 
&element : *node->operations()) { + mlir::Value mlir_element = VisitClass1(element.get()); + mlir_operations.push_back(std::move(mlir_element)); + } + return CreateExpr(node, mlir_strings, mlir_operations); +} + +LiirOptionalListOp AstToLiir::VisitOptionalList(const LiOptionalList *node) { + mlir::ArrayAttr mlir_strings; + if (node->strings().has_value()) { + std::vector mlir_strings_data; + for (const auto &element : *node->strings().value()) { + mlir::StringAttr mlir_element = builder_.getStringAttr(element); + mlir_strings_data.push_back(std::move(mlir_element)); + } + mlir_strings = builder_.getArrayAttr(mlir_strings_data); + } + return CreateExpr(node, mlir_strings); +} + +LiirListOfOptionalOp AstToLiir::VisitListOfOptional(const LiListOfOptional *node) { + std::vector mlir_strings_data; + for (const auto &element : *node->strings()) { + mlir::StringAttr mlir_element; + if (element.has_value()) { + mlir_element = builder_.getStringAttr(element.value()); + } + mlir_strings_data.push_back(std::move(mlir_element)); + } + auto mlir_strings = builder_.getArrayAttr(mlir_strings_data); + std::vector mlir_operations; + for (const auto &element : *node->operations()) { + mlir::Value mlir_element; + if (element.has_value()) { + mlir_element = VisitClass1(element.value().get()); + } else { + mlir_element = CreateExpr(node); + } + mlir_operations.push_back(std::move(mlir_element)); + } + return CreateExpr(node, mlir_strings, mlir_operations); +} + +LiirListOfVariantOp AstToLiir::VisitListOfVariant(const LiListOfVariant *node) { + std::vector mlir_variants_data; + for (const auto &element : *node->variants()) { + mlir::Attribute mlir_element; + switch (element.index()) { + case 0: { + mlir_element = builder_.getBoolAttr(std::get<0>(element)); + break; + } + case 1: { + mlir_element = builder_.getStringAttr(std::get<1>(element)); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + mlir_variants_data.push_back(std::move(mlir_element)); + } + auto mlir_variants 
= builder_.getArrayAttr(mlir_variants_data); + std::vector mlir_operations; + for (const auto &element : *node->operations()) { + mlir::Value mlir_element; + switch (element.index()) { + case 0: { + mlir_element = VisitClass1(std::get<0>(element).get()); + break; + } + case 1: { + mlir_element = VisitClass2(std::get<1>(element).get()); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + mlir_operations.push_back(std::move(mlir_element)); + } + return CreateExpr(node, mlir_variants, mlir_operations); +} + +LiirOptionalListOfOptionalOp AstToLiir::VisitOptionalListOfOptional(const LiOptionalListOfOptional *node) { + mlir::ArrayAttr mlir_variants; + if (node->variants().has_value()) { + std::vector mlir_variants_data; + for (const auto &element : *node->variants().value()) { + mlir::StringAttr mlir_element; + if (element.has_value()) { + mlir_element = builder_.getStringAttr(element.value()); + } + mlir_variants_data.push_back(std::move(mlir_element)); + } + mlir_variants = builder_.getArrayAttr(mlir_variants_data); + } + return CreateExpr(node, mlir_variants); +} + +LiirOptionalListOfVariantOp AstToLiir::VisitOptionalListOfVariant(const LiOptionalListOfVariant *node) { + mlir::ArrayAttr mlir_variants; + if (node->variants().has_value()) { + std::vector mlir_variants_data; + for (const auto &element : *node->variants().value()) { + mlir::Attribute mlir_element; + switch (element.index()) { + case 0: { + mlir_element = builder_.getBoolAttr(std::get<0>(element)); + break; + } + case 1: { + mlir_element = builder_.getStringAttr(std::get<1>(element)); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + mlir_variants_data.push_back(std::move(mlir_element)); + } + mlir_variants = builder_.getArrayAttr(mlir_variants_data); + } + return CreateExpr(node, mlir_variants); +} + +LiirListOfOptionalVariantOp AstToLiir::VisitListOfOptionalVariant(const LiListOfOptionalVariant *node) { + std::vector mlir_variants_data; + for (const auto &element : 
*node->variants()) { + mlir::Attribute mlir_element; + if (element.has_value()) { + switch (element.value().index()) { + case 0: { + mlir_element = builder_.getBoolAttr(std::get<0>(element.value())); + break; + } + case 1: { + mlir_element = builder_.getStringAttr(std::get<1>(element.value())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + mlir_variants_data.push_back(std::move(mlir_element)); + } + auto mlir_variants = builder_.getArrayAttr(mlir_variants_data); + std::vector mlir_operations; + for (const auto &element : *node->operations()) { + mlir::Value mlir_element; + if (element.has_value()) { + switch (element.value().index()) { + case 0: { + mlir_element = VisitClass1(std::get<0>(element.value()).get()); + break; + } + case 1: { + mlir_element = VisitClass2(std::get<1>(element.value()).get()); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } else { + mlir_element = CreateExpr(node); + } + mlir_operations.push_back(std::move(mlir_element)); + } + return CreateExpr(node, mlir_variants, mlir_operations); +} + +LiirOptionalListOfOptionalVariantOp AstToLiir::VisitOptionalListOfOptionalVariant(const LiOptionalListOfOptionalVariant *node) { + mlir::ArrayAttr mlir_variants; + if (node->variants().has_value()) { + std::vector mlir_variants_data; + for (const auto &element : *node->variants().value()) { + mlir::Attribute mlir_element; + if (element.has_value()) { + switch (element.value().index()) { + case 0: { + mlir_element = builder_.getBoolAttr(std::get<0>(element.value())); + break; + } + case 1: { + mlir_element = builder_.getStringAttr(std::get<1>(element.value())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + mlir_variants_data.push_back(std::move(mlir_element)); + } + mlir_variants = builder_.getArrayAttr(mlir_variants_data); + } + return CreateExpr(node, mlir_variants); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff 
--git a/maldoca/astgen/test/list/conversion/ast_to_liir.h b/maldoca/astgen/test/list/conversion/ast_to_liir.h new file mode 100644 index 0000000..570016d --- /dev/null +++ b/maldoca/astgen/test/list/conversion/ast_to_liir.h @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_LIST_CONVERSION_AST_TO_LIIR_H_ +#define MALDOCA_ASTGEN_TEST_LIST_CONVERSION_AST_TO_LIIR_H_ + +#include "mlir/IR/Builders.h" +#include "maldoca/astgen/test/list/ast.generated.h" +#include "maldoca/astgen/test/list/ir.h" + +namespace maldoca { + +class AstToLiir { + public: + explicit AstToLiir(mlir::OpBuilder &builder) : builder_(builder) {} + + LiirClass1Op VisitClass1(const LiClass1 *node); + + LiirClass2Op VisitClass2(const LiClass2 *node); + + LiirSimpleListOp VisitSimpleList(const LiSimpleList *node); + + LiirOptionalListOp VisitOptionalList(const LiOptionalList *node); + + LiirListOfOptionalOp VisitListOfOptional(const LiListOfOptional *node); + + LiirListOfVariantOp VisitListOfVariant(const LiListOfVariant *node); + + LiirOptionalListOfOptionalOp VisitOptionalListOfOptional( + const LiOptionalListOfOptional *node); + + LiirOptionalListOfVariantOp VisitOptionalListOfVariant( + const LiOptionalListOfVariant *node); + + LiirListOfOptionalVariantOp VisitListOfOptionalVariant( + const LiListOfOptionalVariant *node); + + LiirOptionalListOfOptionalVariantOp VisitOptionalListOfOptionalVariant( + const 
LiOptionalListOfOptionalVariant *node); + + private: + template + Op CreateExpr(const Node *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), + std::forward(args)...); + } + + mlir::OpBuilder &builder_; +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_LIST_CONVERSION_AST_TO_LIIR_H_ diff --git a/maldoca/astgen/test/list/conversion/conversion_test.cc b/maldoca/astgen/test/list/conversion/conversion_test.cc new file mode 100644 index 0000000..073c790 --- /dev/null +++ b/maldoca/astgen/test/list/conversion/conversion_test.cc @@ -0,0 +1,127 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/conversion_test_util.h" +#include "maldoca/astgen/test/list/ast.generated.h" +#include "maldoca/astgen/test/list/conversion/ast_to_liir.h" +#include "maldoca/astgen/test/list/conversion/liir_to_ast.h" +#include "maldoca/astgen/test/list/ir.h" + +namespace maldoca { +namespace { + +TEST(ConversionTest, OptionalListWithNullopt) { + constexpr char kAstJsonString[] = R"( + {} + )"; + + constexpr char kExpectedIr[] = R"( +module { + %0 = "liir.optional_list"() : () -> !liir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToLiir::VisitOptionalList, + .ir_to_ast_visit = &LiirToAst::VisitOptionalList, + .expected_ir_dump = kExpectedIr, + }); +} + +TEST(ConversionTest, OptionalList) { + constexpr char kAstJsonString[] = R"( + { + "strings": [ + "a", + "b" + ] + } + )"; + + constexpr char kExpectedIr[] = R"( +module { + %0 = "liir.optional_list"() <{strings = ["a", "b"]}> : () -> !liir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToLiir::VisitOptionalList, + .ir_to_ast_visit = &LiirToAst::VisitOptionalList, + .expected_ir_dump = kExpectedIr, + }); +} + +TEST(ConversionTest, ListOfVariant) { + constexpr char kAstJsonString[] = R"( + { + "variants": [ + true, + "true" + ], + "operations": [ + {} + ] + } + )"; + + constexpr char kExpectedIr[] = R"( +module { + %0 = "liir.class1"() : () -> !liir.any + %1 = "liir.list_of_variant"(%0) <{variants = [true, "true"]}> : (!liir.any) -> !liir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToLiir::VisitListOfVariant, + .ir_to_ast_visit = &LiirToAst::VisitListOfVariant, + .expected_ir_dump = kExpectedIr, + }); +} + +// Disabled because null attributes are not supported. 
+TEST(ConversionTest, DISABLED_ListOfOptional) { + constexpr char kAstJsonString[] = R"( + { + "strings": [ + "true", + null + ], + "operations": [] + } + )"; + + constexpr char kExpectedIr[] = R"( +"builtin.module"() ({ + %0 = "liir.list_of_optional"() <{strings = ["true", <>]}> : () -> !liir.any +}) : () -> () + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToLiir::VisitListOfOptional, + .ir_to_ast_visit = &LiirToAst::VisitListOfOptional, + .expected_ir_dump = kExpectedIr, + }); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/list/conversion/liir_to_ast.generated.cc b/maldoca/astgen/test/list/conversion/liir_to_ast.generated.cc new file mode 100644 index 0000000..9490a9d --- /dev/null +++ b/maldoca/astgen/test/list/conversion/liir_to_ast.generated.cc @@ -0,0 +1,294 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/list/conversion/liir_to_ast.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/base/status_macros.h" +#include "maldoca/astgen/test/list/ast.generated.h" +#include "maldoca/astgen/test/list/ir.h" + +namespace maldoca { + +absl::StatusOr> +LiirToAst::VisitClass1(LiirClass1Op op) { + return Create( + op); +} + +absl::StatusOr> +LiirToAst::VisitClass2(LiirClass2Op op) { + return Create( + op); +} + +absl::StatusOr> +LiirToAst::VisitSimpleList(LiirSimpleListOp op) { + std::vector strings; + for (mlir::Attribute mlir_strings_element_unchecked : op.getStringsAttr().getValue()) { + auto strings_element_attr = llvm::dyn_cast(mlir_strings_element_unchecked); + if (strings_element_attr == nullptr) { + return absl::InvalidArgumentError("Invalid attribute."); + } + std::string strings_element = strings_element_attr.str(); + strings.push_back(std::move(strings_element)); + } + std::vector> operations; + for (mlir::Value mlir_operations_element_unchecked : op.getOperations()) { + auto operations_element_op = llvm::dyn_cast(mlir_operations_element_unchecked.getDefiningOp()); + if (operations_element_op == nullptr) { + return 
absl::InvalidArgumentError( + absl::StrCat("Expected LiirClass1Op, got ", + mlir_operations_element_unchecked.getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr operations_element, VisitClass1(operations_element_op)); + operations.push_back(std::move(operations_element)); + } + return Create( + op, + std::move(strings), + std::move(operations)); +} + +absl::StatusOr> +LiirToAst::VisitOptionalList(LiirOptionalListOp op) { + std::optional> strings; + if (op.getStringsAttr() != nullptr) { + std::vector strings_value; + for (mlir::Attribute mlir_strings_element_unchecked : op.getStringsAttr().getValue()) { + auto strings_element_attr = llvm::dyn_cast(mlir_strings_element_unchecked); + if (strings_element_attr == nullptr) { + return absl::InvalidArgumentError("Invalid attribute."); + } + std::string strings_element = strings_element_attr.str(); + strings_value.push_back(std::move(strings_element)); + } + strings = std::move(strings_value); + } + return Create( + op, + std::move(strings)); +} + +absl::StatusOr> +LiirToAst::VisitListOfOptional(LiirListOfOptionalOp op) { + std::vector> strings; + for (mlir::Attribute mlir_strings_element_unchecked : op.getStringsAttr().getValue()) { + std::optional strings_element; + if (mlir_strings_element_unchecked != nullptr) { + auto strings_element_attr = llvm::dyn_cast(mlir_strings_element_unchecked); + if (strings_element_attr == nullptr) { + return absl::InvalidArgumentError("Invalid attribute."); + } + strings_element = strings_element_attr.str(); + } + strings.push_back(std::move(strings_element)); + } + std::vector>> operations; + for (mlir::Value mlir_operations_element_unchecked : op.getOperations()) { + std::optional> operations_element; + if (!llvm::isa(mlir_operations_element_unchecked.getDefiningOp())) { + auto operations_element_op = llvm::dyn_cast(mlir_operations_element_unchecked.getDefiningOp()); + if (operations_element_op == nullptr) { + return 
absl::InvalidArgumentError( + absl::StrCat("Expected LiirClass1Op, got ", + mlir_operations_element_unchecked.getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(operations_element, VisitClass1(operations_element_op)); + } + operations.push_back(std::move(operations_element)); + } + return Create( + op, + std::move(strings), + std::move(operations)); +} + +absl::StatusOr> +LiirToAst::VisitListOfVariant(LiirListOfVariantOp op) { + std::vector> variants; + for (mlir::Attribute mlir_variants_element_unchecked : op.getVariantsAttr().getValue()) { + std::variant variants_element; + if (auto mlir_variants_element = llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.getValue(); + } else if (auto mlir_variants_element = llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.str(); + } else { + return absl::InvalidArgumentError("mlir_variants_element_unchecked has invalid type."); + } + variants.push_back(std::move(variants_element)); + } + std::vector, std::unique_ptr>> operations; + for (mlir::Value mlir_operations_element_unchecked : op.getOperations()) { + std::variant, std::unique_ptr> operations_element; + if (auto mlir_operations_element = llvm::dyn_cast(mlir_operations_element_unchecked.getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, VisitClass1(mlir_operations_element)); + } else if (auto mlir_operations_element = llvm::dyn_cast(mlir_operations_element_unchecked.getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, VisitClass2(mlir_operations_element)); + } else { + return absl::InvalidArgumentError("mlir_operations_element_unchecked.getDefiningOp() has invalid type."); + } + operations.push_back(std::move(operations_element)); + } + return Create( + op, + std::move(variants), + std::move(operations)); +} + +absl::StatusOr> +LiirToAst::VisitOptionalListOfOptional(LiirOptionalListOfOptionalOp op) { + std::optional>> 
variants; + if (op.getVariantsAttr() != nullptr) { + std::vector> variants_value; + for (mlir::Attribute mlir_variants_element_unchecked : op.getVariantsAttr().getValue()) { + std::optional variants_element; + if (mlir_variants_element_unchecked != nullptr) { + auto variants_element_attr = llvm::dyn_cast(mlir_variants_element_unchecked); + if (variants_element_attr == nullptr) { + return absl::InvalidArgumentError("Invalid attribute."); + } + variants_element = variants_element_attr.str(); + } + variants_value.push_back(std::move(variants_element)); + } + variants = std::move(variants_value); + } + return Create( + op, + std::move(variants)); +} + +absl::StatusOr> +LiirToAst::VisitOptionalListOfVariant(LiirOptionalListOfVariantOp op) { + std::optional>> variants; + if (op.getVariantsAttr() != nullptr) { + std::vector> variants_value; + for (mlir::Attribute mlir_variants_element_unchecked : op.getVariantsAttr().getValue()) { + std::variant variants_element; + if (auto mlir_variants_element = llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.getValue(); + } else if (auto mlir_variants_element = llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.str(); + } else { + return absl::InvalidArgumentError("mlir_variants_element_unchecked has invalid type."); + } + variants_value.push_back(std::move(variants_element)); + } + variants = std::move(variants_value); + } + return Create( + op, + std::move(variants)); +} + +absl::StatusOr> +LiirToAst::VisitListOfOptionalVariant(LiirListOfOptionalVariantOp op) { + std::vector>> variants; + for (mlir::Attribute mlir_variants_element_unchecked : op.getVariantsAttr().getValue()) { + std::optional> variants_element; + if (mlir_variants_element_unchecked != nullptr) { + if (auto mlir_variants_element = llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.getValue(); + } else if (auto mlir_variants_element = 
llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.str(); + } else { + return absl::InvalidArgumentError("mlir_variants_element_unchecked has invalid type."); + } + } + variants.push_back(std::move(variants_element)); + } + std::vector, std::unique_ptr>>> operations; + for (mlir::Value mlir_operations_element_unchecked : op.getOperations()) { + std::optional, std::unique_ptr>> operations_element; + if (!llvm::isa(mlir_operations_element_unchecked.getDefiningOp())) { + if (auto mlir_operations_element = llvm::dyn_cast(mlir_operations_element_unchecked.getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, VisitClass1(mlir_operations_element)); + } else if (auto mlir_operations_element = llvm::dyn_cast(mlir_operations_element_unchecked.getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(operations_element, VisitClass2(mlir_operations_element)); + } else { + return absl::InvalidArgumentError("mlir_operations_element_unchecked.getDefiningOp() has invalid type."); + } + } + operations.push_back(std::move(operations_element)); + } + return Create( + op, + std::move(variants), + std::move(operations)); +} + +absl::StatusOr> +LiirToAst::VisitOptionalListOfOptionalVariant(LiirOptionalListOfOptionalVariantOp op) { + std::optional>>> variants; + if (op.getVariantsAttr() != nullptr) { + std::vector>> variants_value; + for (mlir::Attribute mlir_variants_element_unchecked : op.getVariantsAttr().getValue()) { + std::optional> variants_element; + if (mlir_variants_element_unchecked != nullptr) { + if (auto mlir_variants_element = llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.getValue(); + } else if (auto mlir_variants_element = llvm::dyn_cast(mlir_variants_element_unchecked)) { + variants_element = mlir_variants_element.str(); + } else { + return absl::InvalidArgumentError("mlir_variants_element_unchecked has invalid type."); + } + } + 
variants_value.push_back(std::move(variants_element)); + } + variants = std::move(variants_value); + } + return Create( + op, + std::move(variants)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/list/conversion/liir_to_ast.h b/maldoca/astgen/test/list/conversion/liir_to_ast.h new file mode 100644 index 0000000..b7a538a --- /dev/null +++ b/maldoca/astgen/test/list/conversion/liir_to_ast.h @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_LIST_CONVERSION_LIIR_TO_AST_H_ +#define MALDOCA_ASTGEN_TEST_LIST_CONVERSION_LIIR_TO_AST_H_ + +#include + +#include "mlir/IR/Operation.h" +#include "absl/status/statusor.h" +#include "maldoca/astgen/test/list/ast.generated.h" +#include "maldoca/astgen/test/list/ir.h" + +namespace maldoca { + +class LiirToAst { + public: + absl::StatusOr> VisitClass1(LiirClass1Op op); + + absl::StatusOr> VisitClass2(LiirClass2Op op); + + absl::StatusOr> VisitSimpleList( + LiirSimpleListOp op); + + absl::StatusOr> VisitOptionalList( + LiirOptionalListOp op); + + absl::StatusOr> VisitListOfOptional( + LiirListOfOptionalOp op); + + absl::StatusOr> VisitListOfVariant( + LiirListOfVariantOp op); + + absl::StatusOr> + VisitOptionalListOfOptional(LiirOptionalListOfOptionalOp op); + + absl::StatusOr> + VisitOptionalListOfVariant(LiirOptionalListOfVariantOp op); + + absl::StatusOr> + VisitListOfOptionalVariant(LiirListOfOptionalVariantOp op); + + absl::StatusOr> + VisitOptionalListOfOptionalVariant(LiirOptionalListOfOptionalVariantOp op); + + private: + template + std::unique_ptr Create(mlir::Operation *op, Args &&...args) { + return absl::make_unique(std::forward(args)...); + } +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_LIST_CONVERSION_LIIR_TO_AST_H_ diff --git a/maldoca/astgen/test/list/interfaces.td b/maldoca/astgen/test/list/interfaces.td new file mode 100644 index 0000000..5d18117 --- /dev/null +++ b/maldoca/astgen/test/list/interfaces.td @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +include "mlir/IR/OpBase.td" diff --git a/maldoca/astgen/test/list/ir.cc b/maldoca/astgen/test/list/ir.cc new file mode 100644 index 0000000..e31501e --- /dev/null +++ b/maldoca/astgen/test/list/ir.cc @@ -0,0 +1,63 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/test/list/ir.h" + +// IWYU pragma: begin_keep + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" + +// IWYU pragma: end_keep + +// ============================================================================= +// Dialect Definition +// ============================================================================= + +#include "maldoca/astgen/test/list/liir_dialect.cc.inc" + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. 
+void maldoca::LiirDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "maldoca/astgen/test/list/liir_types.cc.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "maldoca/astgen/test/list/liir_ops.cc.inc" + >(); +} + +// ============================================================================= +// Dialect Interface Definitions +// ============================================================================= + +#include "maldoca/astgen/test/list/interfaces.cc.inc" + +// ============================================================================= +// Dialect Type Definitions +// ============================================================================= + +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/list/liir_types.cc.inc" + +// ============================================================================= +// Dialect Op Definitions +// ============================================================================= + +#define GET_OP_CLASSES +#include "maldoca/astgen/test/list/liir_ops.cc.inc" diff --git a/maldoca/astgen/test/list/ir.h b/maldoca/astgen/test/list/ir.h new file mode 100644 index 0000000..bca3081 --- /dev/null +++ b/maldoca/astgen/test/list/ir.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_LIST_IR_H_ +#define MALDOCA_ASTGEN_TEST_LIST_IR_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" + +// Include the auto-generated header file containing the declaration of the LIIR +// dialect. +#include "maldoca/astgen/test/list/liir_dialect.h.inc" + +// Include the auto-generated header file containing the declarations of the +// LIIR interfaces. +#include "maldoca/astgen/test/list/interfaces.h.inc" + +// Include the auto-generated header file containing the declarations of the +// LIIR types. +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/list/liir_types.h.inc" + +// Include the auto-generated header file containing the declarations of the +// LIIR operations. +#define GET_OP_CLASSES +#include "maldoca/astgen/test/list/liir_ops.h.inc" + +#endif // MALDOCA_ASTGEN_TEST_LIST_IR_H_ diff --git a/maldoca/astgen/test/list/liir_dialect.td b/maldoca/astgen/test/list/liir_dialect.td new file mode 100644 index 0000000..f049579 --- /dev/null +++ b/maldoca/astgen/test/list/liir_dialect.td @@ -0,0 +1,42 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_LIST_LIR_DIALECT_TD_ +#define MALDOCA_ASTGEN_TEST_LIST_LIR_DIALECT_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def Liir_Dialect : Dialect { + let name = "liir"; + let cppNamespace = "::maldoca"; + + let description = [{ + The ListIR, a test IR that makes extensive use of list types. All ops and + fields are directly mapped from the AST. + }]; + + let useDefaultTypePrinterParser = 1; +} + +class Liir_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = ?; +} + +class Liir_Op traits = []> : + Op; + +#endif // MALDOCA_ASTGEN_TEST_LIST_LIR_DIALECT_TD_ diff --git a/maldoca/astgen/test/list/liir_ops.generated.td b/maldoca/astgen/test/list/liir_ops.generated.td new file mode 100644 index 0000000..002cfaa --- /dev/null +++ b/maldoca/astgen/test/list/liir_ops.generated.td @@ -0,0 +1,128 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_LIST_LIIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_LIST_LIIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/list/interfaces.td" +include "maldoca/astgen/test/list/liir_dialect.td" +include "maldoca/astgen/test/list/liir_types.td" + +def LiirClass1Op : Liir_Op<"class1", []> { + let results = (outs + LiirAnyType + ); +} + +def LiirClass2Op : Liir_Op<"class2", []> { + let results = (outs + LiirAnyType + ); +} + +def LiirSimpleListOp : Liir_Op<"simple_list", []> { + let arguments = (ins + TypedArrayAttrBase: $strings, + Variadic: $operations + ); + + let results = (outs + LiirAnyType + ); +} + +def LiirOptionalListOp : Liir_Op<"optional_list", []> { + let arguments = (ins + OptionalAttr>: $strings + ); + + let results = (outs + LiirAnyType + ); +} + +def LiirListOfOptionalOp : Liir_Op<"list_of_optional", []> { + let arguments = (ins + TypedArrayAttrBase, "">: $strings, + Variadic: $operations + ); + + let results = (outs + LiirAnyType + ); +} + +def LiirListOfVariantOp : Liir_Op<"list_of_variant", []> { + let arguments = (ins + TypedArrayAttrBase, "">: $variants, + Variadic: $operations + ); + + let results = (outs + LiirAnyType + ); +} + +def LiirOptionalListOfOptionalOp : Liir_Op<"optional_list_of_optional", []> { + let arguments = (ins + OptionalAttr, "">>: $variants + ); + + let results = (outs + LiirAnyType + ); +} + +def LiirOptionalListOfVariantOp : Liir_Op<"optional_list_of_variant", []> { + let arguments = (ins + OptionalAttr, "">>: $variants + ); + + let results = (outs + LiirAnyType + ); +} + +def LiirListOfOptionalVariantOp : Liir_Op<"list_of_optional_variant", []> { + let 
arguments = (ins + TypedArrayAttrBase>, "">: $variants, + Variadic: $operations + ); + + let results = (outs + LiirAnyType + ); +} + +def LiirOptionalListOfOptionalVariantOp : Liir_Op<"optional_list_of_optional_variant", []> { + let arguments = (ins + OptionalAttr>, "">>: $variants + ); + + let results = (outs + LiirAnyType + ); +} + +#endif // MALDOCA_ASTGEN_TEST_LIST_LIIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/list/liir_ops.td b/maldoca/astgen/test/list/liir_ops.td new file mode 100644 index 0000000..7f2afc8 --- /dev/null +++ b/maldoca/astgen/test/list/liir_ops.td @@ -0,0 +1,27 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_LIST_LIIR_OPS_TD_ +#define MALDOCA_ASTGEN_TEST_LIST_LIIR_OPS_TD_ + +// Import the generated ops. +include "maldoca/astgen/test/list/liir_ops.generated.td" + +def LiirNoneOp : Liir_Op<"none", []> { + let results = (outs + LiirAnyType + ); +} + +#endif // MALDOCA_ASTGEN_TEST_LIST_LIIR_OPS_TD_ diff --git a/maldoca/astgen/test/list/liir_types.td b/maldoca/astgen/test/list/liir_types.td new file mode 100644 index 0000000..b89d2f9 --- /dev/null +++ b/maldoca/astgen/test/list/liir_types.td @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_LIST_LIIR_TYPES_TD_ +#define MALDOCA_ASTGEN_TEST_LIST_LIIR_TYPES_TD_ + +include "maldoca/astgen/test/list/liir_dialect.td" + +def LiirAnyType : Liir_Type<"LiirAny"> { + let summary = "A placeholder singleton type."; + let mnemonic = "any"; + let assemblyFormat = ""; +} + +#endif // MALDOCA_ASTGEN_TEST_LIST_LIIR_TYPES_TD_ diff --git a/maldoca/astgen/test/multiple_inheritance/BUILD b/maldoca/astgen/test/multiple_inheritance/BUILD new file mode 100644 index 0000000..a9e7cf8 --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/BUILD @@ -0,0 +1,58 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package(default_applicable_licenses = ["//:license"]) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = ["ast.generated.h"], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) diff --git a/maldoca/astgen/test/multiple_inheritance/ast.generated.cc b/maldoca/astgen/test/multiple_inheritance/ast.generated.cc new file mode 100644 index 0000000..2bbabc0 --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/ast.generated.cc @@ -0,0 +1,162 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! 
THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#include "maldoca/astgen/test/multiple_inheritance/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// MSourceLocation +// ============================================================================= + +MSourceLocation::MSourceLocation( + double start, + double end) + : start_(std::move(start)), + end_(std::move(end)) {} + +double MSourceLocation::start() const { + return start_; +} + +void MSourceLocation::set_start(double start) { + start_ = start; +} + +double MSourceLocation::end() const { + return end_; +} + +void MSourceLocation::set_end(double end) { + end_ = end; +} + +// ============================================================================= +// MNode +// ============================================================================= + +absl::string_view MNodeTypeToString(MNodeType node_type) { + switch (node_type) { + case MNodeType::kObjectMethod: + return "ObjectMethod"; + } +} + +absl::StatusOr StringToMNodeType(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"ObjectMethod", MNodeType::kObjectMethod}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for MNodeType: ", s)); + } + return it->second; +} + +MNode::MNode( + std::unique_ptr loc) + : loc_(std::move(loc)) {} + 
+MSourceLocation* MNode::loc() { + return loc_.get(); +} + +const MSourceLocation* MNode::loc() const { + return loc_.get(); +} + +void MNode::set_loc(std::unique_ptr loc) { + loc_ = std::move(loc); +} + +// ============================================================================= +// MFunction +// ============================================================================= + +MFunction::MFunction( + std::unique_ptr loc, + std::string id) + : MNode(std::move(loc)), + id_(std::move(id)) {} + +absl::string_view MFunction::id() const { + return id_; +} + +void MFunction::set_id(std::string id) { + id_ = std::move(id); +} + +// ============================================================================= +// MObjectMember +// ============================================================================= + +MObjectMember::MObjectMember( + std::unique_ptr loc, + bool computed) + : MNode(std::move(loc)), + computed_(std::move(computed)) {} + +bool MObjectMember::computed() const { + return computed_; +} + +void MObjectMember::set_computed(bool computed) { + computed_ = computed; +} + +// ============================================================================= +// MObjectMethod +// ============================================================================= + +MObjectMethod::MObjectMethod( + std::unique_ptr loc, + bool computed, + std::string id) + : MNode(std::move(loc)), + MObjectMember(std::move(loc), std::move(computed)), + MFunction(std::move(loc), std::move(id)) {} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/multiple_inheritance/ast.generated.h b/maldoca/astgen/test/multiple_inheritance/ast.generated.h new file mode 100644 index 0000000..57fcfd9 --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/ast.generated.h @@ -0,0 +1,185 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use 
this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_MULTIPLE_INHERITANCE_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_MULTIPLE_INHERITANCE_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +class MSourceLocation { + public: + explicit MSourceLocation( + double start, + double end); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + double start() const; + void set_start(double start); + + double end() const; + void set_end(double end); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr GetStart(const nlohmann::json& json); + static absl::StatusOr GetEnd(const nlohmann::json& json); + + private: + double start_; + double end_; +}; + +enum class MNodeType { + kObjectMethod, +}; + +absl::string_view MNodeTypeToString(MNodeType node_type); +absl::StatusOr StringToMNodeType(absl::string_view s); + +class MNode { + public: + explicit MNode( + std::unique_ptr loc); + + virtual ~MNode() = default; + + virtual MNodeType node_type() const = 0; + + virtual void Serialize(std::ostream& os) const = 0; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + MSourceLocation* loc(); + const MSourceLocation* loc() const; + void set_loc(std::unique_ptr loc); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetLoc(const nlohmann::json& json); + + private: + std::unique_ptr loc_; +}; + +class MFunction : public virtual MNode { + public: + explicit MFunction( + std::unique_ptr loc, + std::string id); + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + absl::string_view id() const; + void set_id(std::string id); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr GetId(const nlohmann::json& json); + + private: + std::string id_; +}; + +class MObjectMember : public virtual MNode { + public: + explicit MObjectMember( + std::unique_ptr loc, + bool computed); + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + bool computed() const; + void set_computed(bool computed); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr GetComputed(const nlohmann::json& json); + + private: + bool computed_; +}; + +class MObjectMethod : public virtual MObjectMember, public virtual MFunction { + public: + explicit MObjectMethod( + std::unique_ptr loc, + bool computed, + std::string id); + + MNodeType node_type() const override { + return MNodeType::kObjectMethod; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. 
+ void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_MULTIPLE_INHERITANCE_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/multiple_inheritance/ast_def.textproto b/maldoca/astgen/test/multiple_inheritance/ast_def.textproto new file mode 100644 index 0000000..51bcf0f --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/ast_def.textproto @@ -0,0 +1,65 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "m" + +# interface SourceLocation { +# start: number +# end: number +# } +nodes { + name: "SourceLocation" + fields { + name: "start" + type { double {} } + } + fields { + name: "end" + type { double {} } + } +} + +# interface Node { +# loc: SourceLocation +# } +nodes { + name: "Node" + fields { + name: "loc" + type { class: "SourceLocation" } + } +} + +# interface Function <: Node { +# id: string +# } +nodes { + name: "Function" + parents: "Node" + fields { + name: "id" + type { string {} } + } +} + +# interface ObjectMember <: Node { +# computed: boolean +# } +nodes { + name: "ObjectMember" + parents: "Node" + fields { + name: "computed" + type { bool {} } + } +} + +# interface ObjectMethod <: ObjectMember, Function { +# type: "ObjectMethod" +# } +nodes { + name: "ObjectMethod" + type: "ObjectMethod" + parents: "ObjectMember" + parents: "Function" +} diff --git a/maldoca/astgen/test/multiple_inheritance/ast_from_json.generated.cc b/maldoca/astgen/test/multiple_inheritance/ast_from_json.generated.cc new file mode 100644 index 0000000..e4edfa9 --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/ast_from_json.generated.cc @@ -0,0 +1,240 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/multiple_inheritance/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +static absl::StatusOr GetType(const nlohmann::json& json) { + auto type_it = json.find("type"); + if (type_it == json.end()) { + return absl::InvalidArgumentError("`type` is undefined."); + } + const nlohmann::json& json_type = type_it.value(); + if (json_type.is_null()) { + return absl::InvalidArgumentError("json_type is null."); + } + if (!json_type.is_string()) { + return absl::InvalidArgumentError("`json_type` expected to be string."); + } + return json_type.get(); +} + +// ============================================================================= +// MSourceLocation +// ============================================================================= + +absl::StatusOr +MSourceLocation::GetStart(const nlohmann::json& json) { + auto start_it = json.find("start"); + if (start_it == 
json.end()) { + return absl::InvalidArgumentError("`start` is undefined."); + } + const nlohmann::json& json_start = start_it.value(); + + if (json_start.is_null()) { + return absl::InvalidArgumentError("json_start is null."); + } + if (!json_start.is_number()) { + return absl::InvalidArgumentError("Expecting json_start.is_number()."); + } + return json_start.get(); +} + +absl::StatusOr +MSourceLocation::GetEnd(const nlohmann::json& json) { + auto end_it = json.find("end"); + if (end_it == json.end()) { + return absl::InvalidArgumentError("`end` is undefined."); + } + const nlohmann::json& json_end = end_it.value(); + + if (json_end.is_null()) { + return absl::InvalidArgumentError("json_end is null."); + } + if (!json_end.is_number()) { + return absl::InvalidArgumentError("Expecting json_end.is_number()."); + } + return json_end.get(); +} + +absl::StatusOr> +MSourceLocation::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto start, MSourceLocation::GetStart(json)); + MALDOCA_ASSIGN_OR_RETURN(auto end, MSourceLocation::GetEnd(json)); + + return absl::make_unique( + std::move(start), + std::move(end)); +} + +// ============================================================================= +// MNode +// ============================================================================= + +absl::StatusOr> +MNode::GetLoc(const nlohmann::json& json) { + auto loc_it = json.find("loc"); + if (loc_it == json.end()) { + return absl::InvalidArgumentError("`loc` is undefined."); + } + const nlohmann::json& json_loc = loc_it.value(); + + if (json_loc.is_null()) { + return absl::InvalidArgumentError("json_loc is null."); + } + return MSourceLocation::FromJson(json_loc); +} + +absl::StatusOr> +MNode::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string 
type, GetType(json)); + + if (type == "ObjectMethod") { + return MObjectMethod::FromJson(json); + } else if (type == "Function") { + return MFunction::FromJson(json); + } else if (type == "ObjectMember") { + return MObjectMember::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// MFunction +// ============================================================================= + +absl::StatusOr +MFunction::GetId(const nlohmann::json& json) { + auto id_it = json.find("id"); + if (id_it == json.end()) { + return absl::InvalidArgumentError("`id` is undefined."); + } + const nlohmann::json& json_id = id_it.value(); + + if (json_id.is_null()) { + return absl::InvalidArgumentError("json_id is null."); + } + if (!json_id.is_string()) { + return absl::InvalidArgumentError("Expecting json_id.is_string()."); + } + return json_id.get(); +} + +absl::StatusOr> +MFunction::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "ObjectMethod") { + return MObjectMethod::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// MObjectMember +// ============================================================================= + +absl::StatusOr +MObjectMember::GetComputed(const nlohmann::json& json) { + auto computed_it = json.find("computed"); + if (computed_it == json.end()) { + return absl::InvalidArgumentError("`computed` is undefined."); + } + const nlohmann::json& json_computed = computed_it.value(); + + if (json_computed.is_null()) { + return absl::InvalidArgumentError("json_computed is null."); + } + if (!json_computed.is_boolean()) { + return 
absl::InvalidArgumentError("Expecting json_computed.is_boolean()."); + } + return json_computed.get(); +} + +absl::StatusOr> +MObjectMember::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "ObjectMethod") { + return MObjectMethod::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// MObjectMethod +// ============================================================================= + +absl::StatusOr> +MObjectMethod::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto loc, MNode::GetLoc(json)); + MALDOCA_ASSIGN_OR_RETURN(auto computed, MObjectMember::GetComputed(json)); + MALDOCA_ASSIGN_OR_RETURN(auto id, MFunction::GetId(json)); + + return absl::make_unique( + std::move(loc), + std::move(computed), + std::move(id)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/multiple_inheritance/ast_gen_test.cc b/maldoca/astgen/test/multiple_inheritance/ast_gen_test.cc new file mode 100644 index 0000000..4a3268c --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/ast_gen_test.cc @@ -0,0 +1,46 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + MultipleInheritance, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = "maldoca/astgen/test/" + "multiple_inheritance/ast_def.textproto", + .ts_interface_path = + "maldoca/astgen/test/multiple_inheritance/" + "ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/multiple_inheritance", + .expected_ast_header_path = + "maldoca/astgen/test/" + "multiple_inheritance/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/" + "multiple_inheritance/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "multiple_inheritance/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "multiple_inheritance/ast_from_json.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/multiple_inheritance/ast_to_json.generated.cc b/maldoca/astgen/test/multiple_inheritance/ast_to_json.generated.cc new file mode 100644 index 0000000..7b473af --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/ast_to_json.generated.cc @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/multiple_inheritance/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// MSourceLocation +// ============================================================================= + +void MSourceLocation::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"start\":" << (nlohmann::json(start_)).dump(); + MaybeAddComma(os, needs_comma); + os << "\"end\":" << (nlohmann::json(end_)).dump(); +} + +void MSourceLocation::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MSourceLocation::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// MNode +// 
============================================================================= + +void MNode::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"loc\":"; + loc_->Serialize(os); +} + +// ============================================================================= +// MFunction +// ============================================================================= + +void MFunction::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"id\":" << (nlohmann::json(id_)).dump(); +} + +// ============================================================================= +// MObjectMember +// ============================================================================= + +void MObjectMember::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"computed\":" << (nlohmann::json(computed_)).dump(); +} + +// ============================================================================= +// MObjectMethod +// ============================================================================= + +void MObjectMethod::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +void MObjectMethod::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"ObjectMethod\""; + MNode::SerializeFields(os, needs_comma); + MObjectMember::SerializeFields(os, needs_comma); + MFunction::SerializeFields(os, needs_comma); + MObjectMethod::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/multiple_inheritance/ast_ts_interface.generated b/maldoca/astgen/test/multiple_inheritance/ast_ts_interface.generated new file mode 100644 index 0000000..7c7aa63 --- /dev/null +++ 
b/maldoca/astgen/test/multiple_inheritance/ast_ts_interface.generated @@ -0,0 +1,19 @@ +interface SourceLocation { + start: /*double*/number + end: /*double*/number +} + +interface Node { + loc: SourceLocation +} + +interface Function <: Node { + id: string +} + +interface ObjectMember <: Node { + computed: boolean +} + +interface ObjectMethod <: ObjectMember, Function { +} diff --git a/maldoca/astgen/test/multiple_inheritance/mir_ops.generated.td b/maldoca/astgen/test/multiple_inheritance/mir_ops.generated.td new file mode 100644 index 0000000..59f791b --- /dev/null +++ b/maldoca/astgen/test/multiple_inheritance/mir_ops.generated.td @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_MULTIPLE_INHERITANCE_MIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_MULTIPLE_INHERITANCE_MIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/multiple_inheritance/interfaces.td" +include "maldoca/astgen/test/multiple_inheritance/mir_dialect.td" +include "maldoca/astgen/test/multiple_inheritance/mir_types.td" + +#endif // MALDOCA_ASTGEN_TEST_MULTIPLE_INHERITANCE_MIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/region/BUILD b/maldoca/astgen/test/region/BUILD new file mode 100644 index 0000000..3551e67 --- /dev/null +++ b/maldoca/astgen/test/region/BUILD @@ -0,0 +1,180 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//maldoca/astgen:__subpackages__", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + "rir_ops.generated.td", + "//maldoca/astgen/test/region/conversion:ast_to_rir.generated.cc", + "//maldoca/astgen/test/region/conversion:rir_to_ast.generated.cc", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = ["ast.generated.h"], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) + +td_library( + name = "interfaces_td_files", + srcs = [ + "interfaces.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "interfaces_inc_gen", + tbl_outs = { + "interfaces.h.inc": ["-gen-op-interface-decls"], + "interfaces.cc.inc": ["-gen-op-interface-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "interfaces.td", + deps = [":interfaces_td_files"], +) + +td_library( + name = "rir_dialect_td_files", + srcs = [ + "rir_dialect.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "rir_dialect_inc_gen", + tbl_outs = { + "rir_dialect.h.inc": [ + 
"-gen-dialect-decls", + "-dialect=rir", + ], + "rir_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=rir", + ], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "rir_dialect.td", + deps = [":rir_dialect_td_files"], +) + +td_library( + name = "rir_types_td_files", + srcs = [ + "rir_types.td", + ], + deps = [ + ":rir_dialect_td_files", + ], +) + +gentbl_cc_library( + name = "rir_types_inc_gen", + tbl_outs = { + "rir_types.h.inc": ["-gen-typedef-decls"], + "rir_types.cc.inc": ["-gen-typedef-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "rir_types.td", + deps = [":rir_types_td_files"], +) + +td_library( + name = "rir_ops_td_files", + srcs = [ + "rir_ops.generated.td", + "rir_ops.td", + ], + deps = [ + ":interfaces_td_files", + ":rir_dialect_td_files", + ":rir_types_td_files", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "rir_ops_inc_gen", + tbl_outs = { + "rir_ops.h.inc": ["-gen-op-decls"], + "rir_ops.cc.inc": ["-gen-op-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "rir_ops.td", + deps = [":rir_ops_td_files"], +) + +cc_library( + name = "ir", + srcs = ["ir.cc"], + hdrs = ["ir.h"], + deps = [ + ":interfaces_inc_gen", + ":rir_dialect_inc_gen", + ":rir_ops_inc_gen", + ":rir_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) diff --git a/maldoca/astgen/test/region/ast.generated.cc b/maldoca/astgen/test/region/ast.generated.cc new file mode 100644 index 0000000..5b17d52 --- /dev/null +++ b/maldoca/astgen/test/region/ast.generated.cc @@ -0,0 +1,179 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#include "maldoca/astgen/test/region/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// RExpr +// ============================================================================= + +// ============================================================================= +// RStmt +// ============================================================================= + +RStmt::RStmt( + std::unique_ptr expr) + : expr_(std::move(expr)) {} + +RExpr* RStmt::expr() { + return expr_.get(); +} + +const RExpr* RStmt::expr() const { + return expr_.get(); +} + +void RStmt::set_expr(std::unique_ptr expr) { + expr_ = std::move(expr); +} + +// ============================================================================= +// RNode +// 
============================================================================= + +RNode::RNode( + std::unique_ptr expr, + std::optional> optional_expr, + std::vector> exprs, + std::unique_ptr stmt, + std::optional> optional_stmt, + std::vector> stmts) + : expr_(std::move(expr)), + optional_expr_(std::move(optional_expr)), + exprs_(std::move(exprs)), + stmt_(std::move(stmt)), + optional_stmt_(std::move(optional_stmt)), + stmts_(std::move(stmts)) {} + +RExpr* RNode::expr() { + return expr_.get(); +} + +const RExpr* RNode::expr() const { + return expr_.get(); +} + +void RNode::set_expr(std::unique_ptr expr) { + expr_ = std::move(expr); +} + +std::optional RNode::optional_expr() { + if (!optional_expr_.has_value()) { + return std::nullopt; + } else { + return optional_expr_.value().get(); + } +} + +std::optional RNode::optional_expr() const { + if (!optional_expr_.has_value()) { + return std::nullopt; + } else { + return optional_expr_.value().get(); + } +} + +void RNode::set_optional_expr(std::optional> optional_expr) { + optional_expr_ = std::move(optional_expr); +} + +std::vector>* RNode::exprs() { + return &exprs_; +} + +const std::vector>* RNode::exprs() const { + return &exprs_; +} + +void RNode::set_exprs(std::vector> exprs) { + exprs_ = std::move(exprs); +} + +RStmt* RNode::stmt() { + return stmt_.get(); +} + +const RStmt* RNode::stmt() const { + return stmt_.get(); +} + +void RNode::set_stmt(std::unique_ptr stmt) { + stmt_ = std::move(stmt); +} + +std::optional RNode::optional_stmt() { + if (!optional_stmt_.has_value()) { + return std::nullopt; + } else { + return optional_stmt_.value().get(); + } +} + +std::optional RNode::optional_stmt() const { + if (!optional_stmt_.has_value()) { + return std::nullopt; + } else { + return optional_stmt_.value().get(); + } +} + +void RNode::set_optional_stmt(std::optional> optional_stmt) { + optional_stmt_ = std::move(optional_stmt); +} + +std::vector>* RNode::stmts() { + return &stmts_; +} + +const std::vector>* 
RNode::stmts() const { + return &stmts_; +} + +void RNode::set_stmts(std::vector> stmts) { + stmts_ = std::move(stmts); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/region/ast.generated.h b/maldoca/astgen/test/region/ast.generated.h new file mode 100644 index 0000000..04c0a45 --- /dev/null +++ b/maldoca/astgen/test/region/ast.generated.h @@ -0,0 +1,145 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_REGION_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_REGION_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +class RExpr { + public: + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. 
+ void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class RStmt { + public: + explicit RStmt( + std::unique_ptr expr); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + RExpr* expr(); + const RExpr* expr() const; + void set_expr(std::unique_ptr expr); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetExpr(const nlohmann::json& json); + + private: + std::unique_ptr expr_; +}; + +class RNode { + public: + explicit RNode( + std::unique_ptr expr, + std::optional> optional_expr, + std::vector> exprs, + std::unique_ptr stmt, + std::optional> optional_stmt, + std::vector> stmts); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + RExpr* expr(); + const RExpr* expr() const; + void set_expr(std::unique_ptr expr); + + std::optional optional_expr(); + std::optional optional_expr() const; + void set_optional_expr(std::optional> optional_expr); + + std::vector>* exprs(); + const std::vector>* exprs() const; + void set_exprs(std::vector> exprs); + + RStmt* stmt(); + const RStmt* stmt() const; + void set_stmt(std::unique_ptr stmt); + + std::optional optional_stmt(); + std::optional optional_stmt() const; + void set_optional_stmt(std::optional> optional_stmt); + + std::vector>* stmts(); + const std::vector>* stmts() const; + void set_stmts(std::vector> stmts); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr> GetExpr(const nlohmann::json& json); + static absl::StatusOr>> GetOptionalExpr(const nlohmann::json& json); + static absl::StatusOr>> GetExprs(const nlohmann::json& json); + static absl::StatusOr> GetStmt(const nlohmann::json& json); + static absl::StatusOr>> GetOptionalStmt(const nlohmann::json& json); + static absl::StatusOr>> GetStmts(const nlohmann::json& json); + + private: + std::unique_ptr expr_; + std::optional> optional_expr_; + std::vector> exprs_; + std::unique_ptr stmt_; + std::optional> optional_stmt_; + std::vector> stmts_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_REGION_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/region/ast_def.textproto b/maldoca/astgen/test/region/ast_def.textproto new file mode 100644 index 0000000..26c0434 --- /dev/null +++ b/maldoca/astgen/test/region/ast_def.textproto @@ -0,0 +1,95 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "r" + +# interface Expr {} +nodes { + name: "Expr" + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface Stmt { +# expr: Expr +# } +nodes { + name: "Stmt" + fields { + name: "expr" + type { + class: "Expr" + } + kind: FIELD_KIND_RVAL + } + kinds: FIELD_KIND_STMT + should_generate_ir_op: true +} + +# interface Node { +# expr: Expr +# optionalExpr: Expr | null +# exprs: [Expr] +# stmt: Stmt +# optionalStmt: Stmt | null +# stmts: [Stmt] +# } +nodes { + name: "Node" + fields { + name: "expr" + type { + class: "Expr" + } + enclose_in_region: true + kind: FIELD_KIND_RVAL + } + fields { + name: "optionalExpr" + type { + class: "Expr" + } + optionalness: OPTIONALNESS_MAYBE_NULL + enclose_in_region: true + kind: FIELD_KIND_RVAL + } + fields { + name: "exprs" + type { + list { + element_type { class: "Expr" } + } + } + enclose_in_region: true + kind: FIELD_KIND_RVAL + } + fields { + name: "stmt" + type { + class: 
"Stmt" + } + enclose_in_region: true + kind: FIELD_KIND_STMT + } + fields { + name: "optionalStmt" + type { + class: "Stmt" + } + optionalness: OPTIONALNESS_MAYBE_NULL + enclose_in_region: true + kind: FIELD_KIND_STMT + } + fields { + name: "stmts" + type { + list { + element_type { class: "Stmt" } + } + } + enclose_in_region: true + kind: FIELD_KIND_STMT + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} diff --git a/maldoca/astgen/test/region/ast_from_json.generated.cc b/maldoca/astgen/test/region/ast_from_json.generated.cc new file mode 100644 index 0000000..f29c8b3 --- /dev/null +++ b/maldoca/astgen/test/region/ast_from_json.generated.cc @@ -0,0 +1,226 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/region/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +// ============================================================================= +// RExpr +// ============================================================================= + +absl::StatusOr> +RExpr::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + + return absl::make_unique( + ); +} + +// ============================================================================= +// RStmt +// ============================================================================= + +absl::StatusOr> +RStmt::GetExpr(const nlohmann::json& json) { + auto expr_it = json.find("expr"); + if (expr_it == json.end()) { + return absl::InvalidArgumentError("`expr` is undefined."); + } + const nlohmann::json& json_expr = expr_it.value(); + + if (json_expr.is_null()) { + return absl::InvalidArgumentError("json_expr is null."); + } + return RExpr::FromJson(json_expr); +} + +absl::StatusOr> +RStmt::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto expr, RStmt::GetExpr(json)); + + return absl::make_unique( + std::move(expr)); +} + +// ============================================================================= +// RNode +// 
============================================================================= + +absl::StatusOr> +RNode::GetExpr(const nlohmann::json& json) { + auto expr_it = json.find("expr"); + if (expr_it == json.end()) { + return absl::InvalidArgumentError("`expr` is undefined."); + } + const nlohmann::json& json_expr = expr_it.value(); + + if (json_expr.is_null()) { + return absl::InvalidArgumentError("json_expr is null."); + } + return RExpr::FromJson(json_expr); +} + +absl::StatusOr>> +RNode::GetOptionalExpr(const nlohmann::json& json) { + auto optional_expr_it = json.find("optionalExpr"); + if (optional_expr_it == json.end()) { + return absl::InvalidArgumentError("`optionalExpr` is undefined."); + } + const nlohmann::json& json_optional_expr = optional_expr_it.value(); + + if (json_optional_expr.is_null()) { + return std::nullopt; + } + return RExpr::FromJson(json_optional_expr); +} + +absl::StatusOr>> +RNode::GetExprs(const nlohmann::json& json) { + auto exprs_it = json.find("exprs"); + if (exprs_it == json.end()) { + return absl::InvalidArgumentError("`exprs` is undefined."); + } + const nlohmann::json& json_exprs = exprs_it.value(); + + if (json_exprs.is_null()) { + return absl::InvalidArgumentError("json_exprs is null."); + } + if (!json_exprs.is_array()) { + return absl::InvalidArgumentError("json_exprs expected to be array."); + } + + std::vector> exprs; + for (const nlohmann::json& json_exprs_element : json_exprs) { + if (json_exprs_element.is_null()) { + return absl::InvalidArgumentError("json_exprs_element is null."); + } + MALDOCA_ASSIGN_OR_RETURN(auto exprs_element, RExpr::FromJson(json_exprs_element)); + exprs.push_back(std::move(exprs_element)); + } + return exprs; +} + +absl::StatusOr> +RNode::GetStmt(const nlohmann::json& json) { + auto stmt_it = json.find("stmt"); + if (stmt_it == json.end()) { + return absl::InvalidArgumentError("`stmt` is undefined."); + } + const nlohmann::json& json_stmt = stmt_it.value(); + + if (json_stmt.is_null()) { + return 
absl::InvalidArgumentError("json_stmt is null."); + } + return RStmt::FromJson(json_stmt); +} + +absl::StatusOr>> +RNode::GetOptionalStmt(const nlohmann::json& json) { + auto optional_stmt_it = json.find("optionalStmt"); + if (optional_stmt_it == json.end()) { + return absl::InvalidArgumentError("`optionalStmt` is undefined."); + } + const nlohmann::json& json_optional_stmt = optional_stmt_it.value(); + + if (json_optional_stmt.is_null()) { + return std::nullopt; + } + return RStmt::FromJson(json_optional_stmt); +} + +absl::StatusOr>> +RNode::GetStmts(const nlohmann::json& json) { + auto stmts_it = json.find("stmts"); + if (stmts_it == json.end()) { + return absl::InvalidArgumentError("`stmts` is undefined."); + } + const nlohmann::json& json_stmts = stmts_it.value(); + + if (json_stmts.is_null()) { + return absl::InvalidArgumentError("json_stmts is null."); + } + if (!json_stmts.is_array()) { + return absl::InvalidArgumentError("json_stmts expected to be array."); + } + + std::vector> stmts; + for (const nlohmann::json& json_stmts_element : json_stmts) { + if (json_stmts_element.is_null()) { + return absl::InvalidArgumentError("json_stmts_element is null."); + } + MALDOCA_ASSIGN_OR_RETURN(auto stmts_element, RStmt::FromJson(json_stmts_element)); + stmts.push_back(std::move(stmts_element)); + } + return stmts; +} + +absl::StatusOr> +RNode::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto expr, RNode::GetExpr(json)); + MALDOCA_ASSIGN_OR_RETURN(auto optional_expr, RNode::GetOptionalExpr(json)); + MALDOCA_ASSIGN_OR_RETURN(auto exprs, RNode::GetExprs(json)); + MALDOCA_ASSIGN_OR_RETURN(auto stmt, RNode::GetStmt(json)); + MALDOCA_ASSIGN_OR_RETURN(auto optional_stmt, RNode::GetOptionalStmt(json)); + MALDOCA_ASSIGN_OR_RETURN(auto stmts, RNode::GetStmts(json)); + + return absl::make_unique( + std::move(expr), + std::move(optional_expr), + std::move(exprs), + 
std::move(stmt), + std::move(optional_stmt), + std::move(stmts)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/region/ast_gen_test.cc b/maldoca/astgen/test/region/ast_gen_test.cc new file mode 100644 index 0000000..b5bac7c --- /dev/null +++ b/maldoca/astgen/test/region/ast_gen_test.cc @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + Region, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = "maldoca/astgen/test/" + "region/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "region/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/region", + .ir_path = "maldoca/astgen/test/region", + .expected_ast_header_path = + "maldoca/astgen/test/region/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/region/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/region/" + "ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/region/" + "ast_from_json.generated.cc", + .expected_ir_tablegen_path = + "maldoca/astgen/test/region/" + "rir_ops.generated.td", + .expected_ast_to_ir_source_path = + "maldoca/astgen/test/region/conversion/" + "ast_to_rir.generated.cc", + .expected_ir_to_ast_source_path = + "maldoca/astgen/test/region/conversion/" + "rir_to_ast.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/region/ast_to_json.generated.cc b/maldoca/astgen/test/region/ast_to_json.generated.cc new file mode 100644 index 0000000..bf661f9 --- /dev/null +++ b/maldoca/astgen/test/region/ast_to_json.generated.cc @@ -0,0 +1,141 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/region/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// RExpr +// ============================================================================= + +void RExpr::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +void RExpr::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + RExpr::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// RStmt +// ============================================================================= + +void RStmt::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"expr\":"; + expr_->Serialize(os); +} + +void RStmt::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + RStmt::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// RNode +// ============================================================================= + +void 
RNode::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"expr\":"; + expr_->Serialize(os); + MaybeAddComma(os, needs_comma); + if (optional_expr_.has_value()) { + os << "\"optionalExpr\":"; + optional_expr_.value()->Serialize(os); + } else { + os << "\"optionalExpr\":" << "null"; + } + MaybeAddComma(os, needs_comma); + os << "\"exprs\":" << "["; + { + bool needs_comma = false; + for (const auto& element : exprs_) { + MaybeAddComma(os, needs_comma); + element->Serialize(os); + } + } + os << "]"; + MaybeAddComma(os, needs_comma); + os << "\"stmt\":"; + stmt_->Serialize(os); + MaybeAddComma(os, needs_comma); + if (optional_stmt_.has_value()) { + os << "\"optionalStmt\":"; + optional_stmt_.value()->Serialize(os); + } else { + os << "\"optionalStmt\":" << "null"; + } + MaybeAddComma(os, needs_comma); + os << "\"stmts\":" << "["; + { + bool needs_comma = false; + for (const auto& element : stmts_) { + MaybeAddComma(os, needs_comma); + element->Serialize(os); + } + } + os << "]"; +} + +void RNode::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + RNode::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/region/ast_ts_interface.generated b/maldoca/astgen/test/region/ast_ts_interface.generated new file mode 100644 index 0000000..28a5b57 --- /dev/null +++ b/maldoca/astgen/test/region/ast_ts_interface.generated @@ -0,0 +1,15 @@ +interface Expr { +} + +interface Stmt { + expr: Expr +} + +interface Node { + expr: Expr + optionalExpr: Expr | null + exprs: [ Expr ] + stmt: Stmt + optionalStmt: Stmt | null + stmts: [ Stmt ] +} diff --git a/maldoca/astgen/test/region/conversion/BUILD b/maldoca/astgen/test/region/conversion/BUILD new file mode 100644 index 0000000..b978eb6 --- /dev/null +++ b/maldoca/astgen/test/region/conversion/BUILD @@ -0,0 +1,82 
@@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_applicable_licenses = ["//:license"]) + +licenses(["notice"]) + +exports_files([ + "ast_to_rir.generated.cc", + "rir_to_ast.generated.cc", +]) + +cc_library( + name = "ast_to_rir", + srcs = ["ast_to_rir.generated.cc"], + hdrs = ["ast_to_rir.h"], + deps = [ + "//maldoca/astgen/test/region:ast", + "//maldoca/astgen/test/region:ir", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "rir_to_ast", + srcs = ["rir_to_ast.generated.cc"], + hdrs = ["rir_to_ast.h"], + deps = [ + "//maldoca/astgen/test/region:ast", + "//maldoca/astgen/test/region:ir", + "//maldoca/base:status", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_test( + name = "conversion_test", + srcs = ["conversion_test.cc"], + deps = [ + ":ast_to_rir", + ":rir_to_ast", + "//maldoca/astgen/test:conversion_test_util", + 
"//maldoca/astgen/test/region:ast", + "//maldoca/astgen/test/region:ir", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/maldoca/astgen/test/region/conversion/ast_to_rir.generated.cc b/maldoca/astgen/test/region/conversion/ast_to_rir.generated.cc new file mode 100644 index 0000000..0b6cf56 --- /dev/null +++ b/maldoca/astgen/test/region/conversion/ast_to_rir.generated.cc @@ -0,0 +1,103 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/region/conversion/ast_to_rir.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/astgen/test/region/ast.generated.h" +#include "maldoca/astgen/test/region/ir.h" + +namespace maldoca { + +RirExprOp AstToRir::VisitExpr(const RExpr *node) { + return CreateExpr(node); +} + +RirStmtOp AstToRir::VisitStmt(const RStmt *node) { + mlir::Value mlir_expr = VisitExpr(node->expr()); + return CreateStmt(node, mlir_expr); +} + +RirNodeOp AstToRir::VisitNode(const RNode *node) { + auto op = CreateExpr(node); + mlir::Region &mlir_expr_region = op.getExpr(); + AppendNewBlockAndPopulate(mlir_expr_region, [&] { + mlir::Value mlir_expr = VisitExpr(node->expr()); + CreateStmt(node, mlir_expr); + }); + if (node->optional_expr().has_value()) { + mlir::Region &mlir_optional_expr_region = op.getOptionalExpr(); + AppendNewBlockAndPopulate(mlir_optional_expr_region, [&] { + mlir::Value mlir_optional_expr = VisitExpr(node->optional_expr().value()); + CreateStmt(node, mlir_optional_expr); + }); + } + mlir::Region &mlir_exprs_region = op.getExprs(); + AppendNewBlockAndPopulate(mlir_exprs_region, [&] { + std::vector mlir_exprs; + for (const auto &element : *node->exprs()) { + mlir::Value mlir_element = VisitExpr(element.get()); + mlir_exprs.push_back(std::move(mlir_element)); + } + CreateStmt(node, mlir_exprs); + }); + mlir::Region &mlir_stmt_region = 
op.getStmt(); + AppendNewBlockAndPopulate(mlir_stmt_region, [&] { + VisitStmt(node->stmt()); + }); + if (node->optional_stmt().has_value()) { + mlir::Region &mlir_optional_stmt_region = op.getOptionalStmt(); + AppendNewBlockAndPopulate(mlir_optional_stmt_region, [&] { + VisitStmt(node->optional_stmt().value()); + }); + } + mlir::Region &mlir_stmts_region = op.getStmts(); + AppendNewBlockAndPopulate(mlir_stmts_region, [&] { + for (const auto &element : *node->stmts()) { + VisitStmt(element.get()); + } + }); + return op; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/region/conversion/ast_to_rir.h b/maldoca/astgen/test/region/conversion/ast_to_rir.h new file mode 100644 index 0000000..33cbd8c --- /dev/null +++ b/maldoca/astgen/test/region/conversion/ast_to_rir.h @@ -0,0 +1,70 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_REGION_CONVERSION_AST_TO_RIR_H_ +#define MALDOCA_ASTGEN_TEST_REGION_CONVERSION_AST_TO_RIR_H_ + +#include +#include + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Block.h" +#include "maldoca/astgen/test/region/ast.generated.h" +#include "maldoca/astgen/test/region/ir.h" + +namespace maldoca { + +class AstToRir { // Builds Rir* ops from RExpr/RStmt/RNode through `builder_`. + public: + explicit AstToRir(mlir::OpBuilder &builder) : builder_(builder) {} + + RirExprOp VisitExpr(const RExpr *node); + + RirStmtOp VisitStmt(const RStmt *node); + + RirNodeOp VisitNode(const RNode *node); + + private: + template + Op CreateExpr(const RNode *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), + std::forward(args)...); + } + + template + Op CreateStmt(const RNode *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), std::nullopt, + std::forward(args)...); + } + + void AppendNewBlockAndPopulate(mlir::Region &region, + std::function populate) { + // Save insertion point. + // Will revert at the end (the guard restores it when it goes out of scope). + mlir::OpBuilder::InsertionGuard insertion_guard(builder_); + + // Insert new block and point builder to it. + mlir::Block &block = region.emplaceBlock(); + builder_.setInsertionPointToStart(&block); + + populate(); // Runs while builder_ points into the new block. + } + + mlir::OpBuilder &builder_; // Reference to the caller's builder; not owned. +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_REGION_CONVERSION_AST_TO_RIR_H_ diff --git a/maldoca/astgen/test/region/conversion/conversion_test.cc b/maldoca/astgen/test/region/conversion/conversion_test.cc new file mode 100644 index 0000000..8835ada --- /dev/null +++ b/maldoca/astgen/test/region/conversion/conversion_test.cc @@ -0,0 +1,91 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/astgen/test/conversion_test_util.h" +#include "maldoca/astgen/test/region/ast.generated.h" +#include "maldoca/astgen/test/region/conversion/ast_to_rir.h" +#include "maldoca/astgen/test/region/conversion/rir_to_ast.h" +#include "maldoca/astgen/test/region/ir.h" + +namespace maldoca { +namespace { + +TEST(ConversionTest, ConversionTest) { + constexpr char kAstJsonString[] = R"( + { + "expr": {}, + "exprs": [ + {} + ], + "optionalExpr": null, + "optionalStmt": null, + "stmt": { + "expr": {} + }, + "stmts": [ + { + "expr": {} + } + ] + } + )"; + + constexpr char kExpectedIr[] = R"( +module { + %0 = "rir.node"() ({ + %1 = "rir.expr"() : () -> !rir.any + "rir.expr_region_end"(%1) : (!rir.any) -> () + }, { + }, { + %1 = "rir.expr"() : () -> !rir.any + "rir.exprs_region_end"(%1) : (!rir.any) -> () + }, { + %1 = "rir.expr"() : () -> !rir.any + "rir.stmt"(%1) : (!rir.any) -> () + }, { + }, { + %1 = "rir.expr"() : () -> !rir.any + "rir.stmt"(%1) : (!rir.any) -> () + }) : () -> !rir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToRir::VisitNode, + .ir_to_ast_visit = 
&RirToAst::VisitNode, + .expected_ir_dump = kExpectedIr, + }); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/region/conversion/rir_to_ast.generated.cc b/maldoca/astgen/test/region/conversion/rir_to_ast.generated.cc new file mode 100644 index 0000000..007e80f --- /dev/null +++ b/maldoca/astgen/test/region/conversion/rir_to_ast.generated.cc @@ -0,0 +1,153 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/region/conversion/rir_to_ast.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/base/status_macros.h" +#include "maldoca/astgen/test/region/ast.generated.h" +#include "maldoca/astgen/test/region/ir.h" + +namespace maldoca { + +absl::StatusOr> +RirToAst::VisitExpr(RirExprOp op) { + return Create( + op); +} + +absl::StatusOr> +RirToAst::VisitStmt(RirStmtOp op) { + auto expr_op = llvm::dyn_cast(op.getExpr().getDefiningOp()); + if (expr_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected RirExprOp, got ", + op.getExpr().getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr expr, VisitExpr(expr_op)); + return Create( + op, + std::move(expr)); +} + +absl::StatusOr> +RirToAst::VisitNode(RirNodeOp op) { + MALDOCA_ASSIGN_OR_RETURN(auto mlir_expr_value, GetExprRegionValue(op.getExpr())); + auto expr_op = llvm::dyn_cast(mlir_expr_value.getDefiningOp()); + if (expr_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected RirExprOp, got ", + mlir_expr_value.getDefiningOp()->getName().getStringRef().str(), ".")); + } + 
MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr expr, VisitExpr(expr_op)); + std::optional> optional_expr; + if (!op.getOptionalExpr().empty()) { + MALDOCA_ASSIGN_OR_RETURN(auto mlir_optional_expr_value, GetExprRegionValue(op.getOptionalExpr())); + auto optional_expr_op = llvm::dyn_cast(mlir_optional_expr_value.getDefiningOp()); + if (optional_expr_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected RirExprOp, got ", + mlir_optional_expr_value.getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(optional_expr, VisitExpr(optional_expr_op)); + } + MALDOCA_ASSIGN_OR_RETURN(auto mlir_exprs_values, GetExprsRegionValues(op.getExprs())); + std::vector> exprs; + for (mlir::Value mlir_exprs_element_unchecked : mlir_exprs_values) { + auto exprs_element_op = llvm::dyn_cast(mlir_exprs_element_unchecked.getDefiningOp()); + if (exprs_element_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected RirExprOp, got ", + mlir_exprs_element_unchecked.getDefiningOp()->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr exprs_element, VisitExpr(exprs_element_op)); + exprs.push_back(std::move(exprs_element)); + } + MALDOCA_ASSIGN_OR_RETURN(auto mlir_stmt_operation, GetStmtRegionOperation(op.getStmt())); + auto stmt_op = llvm::dyn_cast(mlir_stmt_operation); + if (stmt_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected RirStmtOp, got ", + mlir_stmt_operation->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr stmt, VisitStmt(stmt_op)); + std::optional> optional_stmt; + if (!op.getOptionalStmt().empty()) { + MALDOCA_ASSIGN_OR_RETURN(auto mlir_optional_stmt_operation, GetStmtRegionOperation(op.getOptionalStmt())); + auto optional_stmt_op = llvm::dyn_cast(mlir_optional_stmt_operation); + if (optional_stmt_op == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Expected RirStmtOp, got ", + 
mlir_optional_stmt_operation->getName().getStringRef().str(), ".")); + } + MALDOCA_ASSIGN_OR_RETURN(optional_stmt, VisitStmt(optional_stmt_op)); + } + MALDOCA_ASSIGN_OR_RETURN(auto mlir_stmts_block, GetStmtsRegionBlock(op.getStmts())); + std::vector> stmts; + for (mlir::Operation& mlir_stmts_element_unchecked : *mlir_stmts_block) { + auto stmts_element_op = llvm::dyn_cast(mlir_stmts_element_unchecked); + if (stmts_element_op == nullptr) { + continue; + } + MALDOCA_ASSIGN_OR_RETURN(std::unique_ptr stmts_element, VisitStmt(stmts_element_op)); + stmts.push_back(std::move(stmts_element)); + } + return Create( + op, + std::move(expr), + std::move(optional_expr), + std::move(exprs), + std::move(stmt), + std::move(optional_stmt), + std::move(stmts)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/region/conversion/rir_to_ast.h b/maldoca/astgen/test/region/conversion/rir_to_ast.h new file mode 100644 index 0000000..c225479 --- /dev/null +++ b/maldoca/astgen/test/region/conversion/rir_to_ast.h @@ -0,0 +1,106 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_REGION_CONVERSION_RIR_TO_AST_H_ +#define MALDOCA_ASTGEN_TEST_REGION_CONVERSION_RIR_TO_AST_H_ + +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "maldoca/astgen/test/region/ast.generated.h" +#include "maldoca/astgen/test/region/ir.h" + +namespace maldoca { + +class RirToAst { // Rebuilds AST objects from RirExprOp/RirStmtOp/RirNodeOp. + public: + absl::StatusOr> VisitExpr(RirExprOp op); + + absl::StatusOr> VisitStmt(RirStmtOp op); + + absl::StatusOr> VisitNode(RirNodeOp op); + + private: + template + std::unique_ptr Create(mlir::Operation *op, Args &&...args) { + return absl::make_unique(std::forward(args)...); // `op` itself is not used here. + } + + absl::StatusOr GetExprRegionValue(mlir::Region &region) { + if (!region.hasOneBlock()) { + return absl::InvalidArgumentError( + "Region should have exactly one block."); + } + mlir::Block &block = region.front(); + if (block.empty()) { + return absl::InvalidArgumentError("Block cannot be empty."); + } + auto expr_region_end = llvm::dyn_cast(block.back()); + if (expr_region_end == nullptr) { + return absl::InvalidArgumentError( + "Block should end with RirExprRegionEndOp."); + } + return expr_region_end.getArgument(); + } + + absl::StatusOr GetExprsRegionValues(mlir::Region &region) { + if (!region.hasOneBlock()) { + return absl::InvalidArgumentError( + "Region should have exactly one block."); + } + mlir::Block &block = region.front(); + if (block.empty()) { + return absl::InvalidArgumentError("Block cannot be empty."); + } + auto exprs_region_end = llvm::dyn_cast(block.back()); + if (exprs_region_end == nullptr) { + return absl::InvalidArgumentError( + "Block should end with RirExprsRegionEndOp."); + } + return exprs_region_end.getArguments(); + } + + absl::StatusOr GetStmtRegionOperation( + mlir::Region &region) { + if (!region.hasOneBlock()) { + return absl::InvalidArgumentError( + 
"Region should have exactly one block."); + } + mlir::Block &block = region.front(); + if (block.empty()) { + return absl::InvalidArgumentError("Block cannot be empty."); + } + return &block.back(); // Last operation of the single block. + } + + absl::StatusOr GetStmtsRegionBlock(mlir::Region &region) { + if (!region.hasOneBlock()) { + return absl::InvalidArgumentError( + "Region should have exactly one block."); + } + mlir::Block &block = region.front(); + return &block; // Unlike the helpers above, an empty block is not rejected here. + } +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_REGION_CONVERSION_RIR_TO_AST_H_ diff --git a/maldoca/astgen/test/region/interfaces.td b/maldoca/astgen/test/region/interfaces.td new file mode 100644 index 0000000..5d18117 --- /dev/null +++ b/maldoca/astgen/test/region/interfaces.td @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "mlir/IR/OpBase.td" diff --git a/maldoca/astgen/test/region/ir.cc b/maldoca/astgen/test/region/ir.cc new file mode 100644 index 0000000..6fb3cf5 --- /dev/null +++ b/maldoca/astgen/test/region/ir.cc @@ -0,0 +1,131 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "maldoca/astgen/test/region/ir.h" + +// IWYU pragma: begin_keep + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Region.h" + +// IWYU pragma: end_keep + +// ============================================================================= +// Dialect Definition +// ============================================================================= + +#include "maldoca/astgen/test/region/rir_dialect.cc.inc" + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. 
+void maldoca::RirDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "maldoca/astgen/test/region/rir_types.cc.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "maldoca/astgen/test/region/rir_ops.cc.inc" + >(); +} + +// ============================================================================= +// Dialect Interface Definitions +// ============================================================================= + +#include "maldoca/astgen/test/region/interfaces.cc.inc" + +// ============================================================================= +// Dialect Type Definitions +// ============================================================================= + +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/region/rir_types.cc.inc" + +// ============================================================================= +// Dialect Op Definitions +// ============================================================================= + +#define GET_OP_CLASSES +#include "maldoca/astgen/test/region/rir_ops.cc.inc" + +// ============================================================================= +// Utils +// ============================================================================= + +namespace maldoca { + +bool IsUnknownRegion(mlir::Region &region) { + // Region must have exactly one block. + return llvm::hasSingleElement(region); +} + +bool IsExprRegion(mlir::Region &region) { + // Region must have exactly one block. + if (!llvm::hasSingleElement(region)) { + return false; + } + + mlir::Block &block = region.front(); + + // Block must have at least one op (terminator). + if (block.empty()) { + return false; + } + + auto *terminator = &block.back(); + return llvm::isa(terminator); +} + +bool IsExprsRegion(mlir::Region &region) { + // Region must have exactly one block. + if (!llvm::hasSingleElement(region)) { + return false; + } + + mlir::Block &block = region.front(); + + // Block must have at least one op (terminator).
+ if (block.empty()) { + return false; + } + + auto *terminator = &block.back(); + return llvm::isa(terminator); +} + +bool IsStmtRegion(mlir::Region &region) { + // Region must have exactly one block. + if (!llvm::hasSingleElement(region)) { + return false; + } + + mlir::Block &block = region.front(); + + // Block must have at least one op (the statement). + return !block.empty(); +} + +bool IsStmtsRegion(mlir::Region &region) { + // Region must have exactly one block. + return llvm::hasSingleElement(region); +} + +} // namespace maldoca diff --git a/maldoca/astgen/test/region/ir.h b/maldoca/astgen/test/region/ir.h new file mode 100644 index 0000000..fa46364 --- /dev/null +++ b/maldoca/astgen/test/region/ir.h @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_REGION_IR_H_ +#define MALDOCA_ASTGEN_TEST_REGION_IR_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" + +namespace maldoca { + +// Checks that region contains a single block. There is no restriction on the +// block. This means that this region can be any of the four kinds below. +bool IsUnknownRegion(mlir::Region &region); + +// Checks that the region contains a single block that terminates with +// RirExprRegionEndOp. This means that this region calculates a single +// expression.
+bool IsExprRegion(mlir::Region &region); + +// Checks that the region contains a single block that terminates with +// RirExprsRegionEndOp. This means that this region calculates a list of +// expressions. +bool IsExprsRegion(mlir::Region &region); + +// Checks that the region contains a single block that's non-empty. This means +// that this region contains a statement. +bool IsStmtRegion(mlir::Region &region); + +// Checks that the region contains a single block. There is no restriction on +// the block. This means that this region contains a list of statements (the +// most unrestrictive case). +bool IsStmtsRegion(mlir::Region &region); + +} // namespace maldoca + +// Include the auto-generated header file containing the declaration of the RIR +// dialect. +#include "maldoca/astgen/test/region/rir_dialect.h.inc" + +// Include the auto-generated header file containing the declarations of the RIR +// interfaces. +#include "maldoca/astgen/test/region/interfaces.h.inc" + +// Include the auto-generated header file containing the declarations of the RIR +// types. +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/region/rir_types.h.inc" + +// Include the auto-generated header file containing the declarations of the RIR +// operations. +#define GET_OP_CLASSES +#include "maldoca/astgen/test/region/rir_ops.h.inc" + +#endif // MALDOCA_ASTGEN_TEST_REGION_IR_H_ diff --git a/maldoca/astgen/test/region/rir_dialect.td b/maldoca/astgen/test/region/rir_dialect.td new file mode 100644 index 0000000..5f4ea18 --- /dev/null +++ b/maldoca/astgen/test/region/rir_dialect.td @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_REGION_RIR_DIALECT_TD_ +#define MALDOCA_ASTGEN_TEST_REGION_RIR_DIALECT_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def Rir_Dialect : Dialect { + let name = "rir"; + let cppNamespace = "::maldoca"; + + let description = [{ + The RegionIR, a test IR that has fields embedded in regions. All ops and + fields are directly mapped from the AST. + }]; + + let useDefaultTypePrinterParser = 1; +} + +class Rir_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = ?; +} + +class Rir_Op traits = []> : + Op; + +def UnknownRegion : Region>; + +def ExprRegion : Region>; + +def ExprsRegion : Region>; + +def StmtRegion : Region>; + +def StmtsRegion : Region>; + +def RegionIsEmpty : CPred<"$_self.empty()">; + +class OptionalRegion : Region< + Or<[region.predicate, RegionIsEmpty]>, + region.summary +>; + +#endif // MALDOCA_ASTGEN_TEST_REGION_RIR_DIALECT_TD_ diff --git a/maldoca/astgen/test/region/rir_ops.generated.td b/maldoca/astgen/test/region/rir_ops.generated.td new file mode 100644 index 0000000..8540e39 --- /dev/null +++ b/maldoca/astgen/test/region/rir_ops.generated.td @@ -0,0 +1,144 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_REGION_RIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_REGION_RIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/region/interfaces.td" +include "maldoca/astgen/test/region/rir_dialect.td" +include "maldoca/astgen/test/region/rir_types.td" + +// rir.*_region_end: An artificial op at the end of a region to collect +// expression-related values. +// +// Take rir.exprs_region_end as an example: +// ====================================== +// +// Consider the following function declaration: +// ``` +// function foo(a, b = defaultValue) { +// ... +// } +// ``` +// +// We lower it to the following IR (simplified): +// ``` +// %0 = rir.identifier_ref {"foo"} +// rir.function_declaration(%0) ( +// // params +// { +// %1 = rir.identifier_ref {"a"} +// %2 = rir.identifier_ref {"b"} +// %3 = rir.identifier {"defaultValue"} +// %4 = rir.assignment_pattern_ref(%2, %3) +// rir.exprs_region_end(%1, %4) +// }, +// // body +// { +// ... +// } +// ) +// ``` +// +// We can see that: +// +// 1.
We put the parameter-related ops in a region, instead of taking them as +// normal arguments. In other words, we don't do this: +// +// ``` +// %0 = rir.identifier_ref {"foo"} +// %1 = rir.identifier_ref {"a"} +// %2 = rir.identifier_ref {"b"} +// %3 = rir.identifier {"defaultValue"} +// %4 = rir.assignment_pattern_ref(%2, %3) +// rir.function_declaration(%0, [%1, %4]) ( +// // body +// { +// ... +// } +// ) +// ``` +// +// The reason is that sometimes an argument might have a default value, and +// the evaluation of that default value happens once for each function call +// (i.e. it happens "within" the function). If we take the parameter as +// normal argument, then %3 is only evaluated once - at function definition +// time. +// +// 2. Even though the function has two parameters, we use 4 ops to represent +// them. This is because some parameters are more complex and require more +// than one op. +// +// 3. We use "rir.exprs_region_end" to list the "top-level" ops for the +// parameters. In the example above, ops [%2, %3, %4] all represent the +// parameter "b = defaultValue", but %4 is the top-level one. In other words, +// %4 is the root of the tree [%2, %3, %4]. +// +// 4. Strictly speaking, we don't really need "rir.exprs_region_end". The ops +// within the "params" region form several trees, and we can figure out what +// the roots are (a root is an op whose return value is not used by any other +// op). So the use of "rir.exprs_region_end" is mostly for convenience. 
+def RirExprRegionEndOp : Rir_Op<"expr_region_end", [Terminator]> { + let arguments = (ins + AnyType: $argument + ); +} + +def RirExprsRegionEndOp : Rir_Op<"exprs_region_end", [Terminator]> { + let arguments = (ins + Variadic: $arguments + ); +} + +def RirExprOp : Rir_Op<"expr", []> { + let results = (outs + RirAnyType + ); +} + +def RirStmtOp : Rir_Op<"stmt", []> { + let arguments = (ins + AnyType: $expr + ); +} + +def RirNodeOp : Rir_Op< + "node", [ + NoTerminator + ]> { + let regions = (region + ExprRegion: $expr, + OptionalRegion: $optional_expr, + ExprsRegion: $exprs, + StmtRegion: $stmt, + OptionalRegion: $optional_stmt, + StmtsRegion: $stmts + ); + + let results = (outs + RirAnyType + ); +} + +#endif // MALDOCA_ASTGEN_TEST_REGION_RIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/region/rir_ops.td b/maldoca/astgen/test/region/rir_ops.td new file mode 100644 index 0000000..52f4e60 --- /dev/null +++ b/maldoca/astgen/test/region/rir_ops.td @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_REGION_RIR_OPS_TD_ +#define MALDOCA_ASTGEN_TEST_REGION_RIR_OPS_TD_ + +// Import the generated ops. 
+include "maldoca/astgen/test/region/rir_ops.generated.td" + +#endif // MALDOCA_ASTGEN_TEST_REGION_RIR_OPS_TD_ diff --git a/maldoca/astgen/test/region/rir_types.td b/maldoca/astgen/test/region/rir_types.td new file mode 100644 index 0000000..34dadac --- /dev/null +++ b/maldoca/astgen/test/region/rir_types.td @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_REGION_RIR_TYPES_TD_ +#define MALDOCA_ASTGEN_TEST_REGION_RIR_TYPES_TD_ + +include "maldoca/astgen/test/region/rir_dialect.td" + +def RirAnyType : Rir_Type<"RirAny"> { + let summary = "A placeholder singleton type."; + let mnemonic = "any"; + let assemblyFormat = ""; +} + +#endif // MALDOCA_ASTGEN_TEST_REGION_RIR_TYPES_TD_ diff --git a/maldoca/astgen/test/typed_lambda/BUILD b/maldoca/astgen/test/typed_lambda/BUILD new file mode 100644 index 0000000..325c332 --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/BUILD @@ -0,0 +1,69 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package(default_applicable_licenses = ["//:license"]) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = ["ast.generated.h"], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) + +cc_test( + name = "ast_test", + srcs = ["ast_test.cc"], + deps = [ + ":ast", + "//maldoca/base/testing:status_matchers", + "@googletest//:gtest_main", + "@nlohmann_json//:json", + ], +) diff --git a/maldoca/astgen/test/typed_lambda/ast.generated.cc b/maldoca/astgen/test/typed_lambda/ast.generated.cc new file mode 100644 index 0000000..837b40d --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/ast.generated.cc @@ -0,0 +1,272 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#include "maldoca/astgen/test/typed_lambda/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// TlNode +// ============================================================================= + +absl::string_view TlNodeTypeToString(TlNodeType node_type) { + switch (node_type) { + case TlNodeType::kLiteral: + return "Literal"; + case TlNodeType::kVariable: + return "Variable"; + case TlNodeType::kFunctionDefinition: + return "FunctionDefinition"; + case TlNodeType::kFunctionCall: + return "FunctionCall"; + case TlNodeType::kLiteralType: + return "LiteralType"; + case TlNodeType::kFunctionType: + return "FunctionType"; + } +} + +absl::StatusOr StringToTlNodeType(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"Literal", TlNodeType::kLiteral}, + 
{"Variable", TlNodeType::kVariable}, + {"FunctionDefinition", TlNodeType::kFunctionDefinition}, + {"FunctionCall", TlNodeType::kFunctionCall}, + {"LiteralType", TlNodeType::kLiteralType}, + {"FunctionType", TlNodeType::kFunctionType}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for TlNodeType: ", s)); + } + return it->second; +} + +// ============================================================================= +// TlExpression +// ============================================================================= + +// ============================================================================= +// TlType +// ============================================================================= + +// ============================================================================= +// TlLiteral +// ============================================================================= + +TlLiteral::TlLiteral( + std::variant value) + : TlNode(), + TlExpression(), + value_(std::move(value)) {} + +std::variant TlLiteral::value() const { + switch (value_.index()) { + case 0: { + return std::get<0>(value_); + } + case 1: { + return std::get<1>(value_); + } + case 2: { + return std::get<2>(value_); + } + case 3: { + return std::get<3>(value_); + } + default: + LOG(FATAL) << "Unreachable code."; + } +} + +void TlLiteral::set_value(std::variant value) { + value_ = std::move(value); +} + +// ============================================================================= +// TlVariable +// ============================================================================= + +TlVariable::TlVariable( + std::string identifier) + : TlNode(), + TlExpression(), + identifier_(std::move(identifier)) {} + +absl::string_view TlVariable::identifier() const { + return identifier_; +} + +void TlVariable::set_identifier(std::string identifier) { + identifier_ = std::move(identifier); +} + +// 
============================================================================= +// TlFunctionDefinition +// ============================================================================= + +TlFunctionDefinition::TlFunctionDefinition( + std::unique_ptr parameter, + std::unique_ptr parameter_type, + std::unique_ptr body) + : TlNode(), + TlExpression(), + parameter_(std::move(parameter)), + parameter_type_(std::move(parameter_type)), + body_(std::move(body)) {} + +TlVariable* TlFunctionDefinition::parameter() { + return parameter_.get(); +} + +const TlVariable* TlFunctionDefinition::parameter() const { + return parameter_.get(); +} + +void TlFunctionDefinition::set_parameter(std::unique_ptr parameter) { + parameter_ = std::move(parameter); +} + +TlType* TlFunctionDefinition::parameter_type() { + return parameter_type_.get(); +} + +const TlType* TlFunctionDefinition::parameter_type() const { + return parameter_type_.get(); +} + +void TlFunctionDefinition::set_parameter_type(std::unique_ptr parameter_type) { + parameter_type_ = std::move(parameter_type); +} + +TlExpression* TlFunctionDefinition::body() { + return body_.get(); +} + +const TlExpression* TlFunctionDefinition::body() const { + return body_.get(); +} + +void TlFunctionDefinition::set_body(std::unique_ptr body) { + body_ = std::move(body); +} + +// ============================================================================= +// TlFunctionCall +// ============================================================================= + +TlFunctionCall::TlFunctionCall( + std::unique_ptr caller, + std::unique_ptr callee) + : TlNode(), + TlExpression(), + caller_(std::move(caller)), + callee_(std::move(callee)) {} + +TlExpression* TlFunctionCall::caller() { + return caller_.get(); +} + +const TlExpression* TlFunctionCall::caller() const { + return caller_.get(); +} + +void TlFunctionCall::set_caller(std::unique_ptr caller) { + caller_ = std::move(caller); +} + +TlExpression* TlFunctionCall::callee() { + return 
callee_.get(); +} + +const TlExpression* TlFunctionCall::callee() const { + return callee_.get(); +} + +void TlFunctionCall::set_callee(std::unique_ptr callee) { + callee_ = std::move(callee); +} + +// ============================================================================= +// TlLiteralType +// ============================================================================= + +// ============================================================================= +// TlFunctionType +// ============================================================================= + +TlFunctionType::TlFunctionType( + std::unique_ptr parameter_type, + std::unique_ptr body_type) + : TlNode(), + TlType(), + parameter_type_(std::move(parameter_type)), + body_type_(std::move(body_type)) {} + +TlType* TlFunctionType::parameter_type() { + return parameter_type_.get(); +} + +const TlType* TlFunctionType::parameter_type() const { + return parameter_type_.get(); +} + +void TlFunctionType::set_parameter_type(std::unique_ptr parameter_type) { + parameter_type_ = std::move(parameter_type); +} + +TlType* TlFunctionType::body_type() { + return body_type_.get(); +} + +const TlType* TlFunctionType::body_type() const { + return body_type_.get(); +} + +void TlFunctionType::set_body_type(std::unique_ptr body_type) { + body_type_ = std::move(body_type); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/typed_lambda/ast.generated.h b/maldoca/astgen/test/typed_lambda/ast.generated.h new file mode 100644 index 0000000..318a31f --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/ast.generated.h @@ -0,0 +1,292 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_TYPED_LAMBDA_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_TYPED_LAMBDA_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +enum class TlNodeType { + kLiteral, + kVariable, + kFunctionDefinition, + kFunctionCall, + kLiteralType, + kFunctionType, +}; + +absl::string_view TlNodeTypeToString(TlNodeType node_type); +absl::StatusOr StringToTlNodeType(absl::string_view s); + +class TlNode { + public: + virtual ~TlNode() = default; + + virtual TlNodeType node_type() const = 0; + + virtual void Serialize(std::ostream& os) const = 0; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class TlExpression : public virtual TlNode { + public: + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. 
+ // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class TlType : public virtual TlNode { + public: + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class TlLiteral : public virtual TlExpression { + public: + explicit TlLiteral( + std::variant value); + + TlNodeType node_type() const override { + return TlNodeType::kLiteral; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::variant value() const; + void set_value(std::variant value); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetValue(const nlohmann::json& json); + + private: + std::variant value_; +}; + +class TlVariable : public virtual TlExpression { + public: + explicit TlVariable( + std::string identifier); + + TlNodeType node_type() const override { + return TlNodeType::kVariable; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + absl::string_view identifier() const; + void set_identifier(std::string identifier); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr GetIdentifier(const nlohmann::json& json); + + private: + std::string identifier_; +}; + +class TlFunctionDefinition : public virtual TlExpression { + public: + explicit TlFunctionDefinition( + std::unique_ptr parameter, + std::unique_ptr parameter_type, + std::unique_ptr body); + + TlNodeType node_type() const override { + return TlNodeType::kFunctionDefinition; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + TlVariable* parameter(); + const TlVariable* parameter() const; + void set_parameter(std::unique_ptr parameter); + + TlType* parameter_type(); + const TlType* parameter_type() const; + void set_parameter_type(std::unique_ptr parameter_type); + + TlExpression* body(); + const TlExpression* body() const; + void set_body(std::unique_ptr body); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr> GetParameter(const nlohmann::json& json); + static absl::StatusOr> GetParameterType(const nlohmann::json& json); + static absl::StatusOr> GetBody(const nlohmann::json& json); + + private: + std::unique_ptr parameter_; + std::unique_ptr parameter_type_; + std::unique_ptr body_; +}; + +class TlFunctionCall : public virtual TlExpression { + public: + explicit TlFunctionCall( + std::unique_ptr caller, + std::unique_ptr callee); + + TlNodeType node_type() const override { + return TlNodeType::kFunctionCall; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + TlExpression* caller(); + const TlExpression* caller() const; + void set_caller(std::unique_ptr caller); + + TlExpression* callee(); + const TlExpression* callee() const; + void set_callee(std::unique_ptr callee); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetCaller(const nlohmann::json& json); + static absl::StatusOr> GetCallee(const nlohmann::json& json); + + private: + std::unique_ptr caller_; + std::unique_ptr callee_; +}; + +class TlLiteralType : public virtual TlType { + public: + TlNodeType node_type() const override { + return TlNodeType::kLiteralType; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. 
+ void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class TlFunctionType : public virtual TlType { + public: + explicit TlFunctionType( + std::unique_ptr parameter_type, + std::unique_ptr body_type); + + TlNodeType node_type() const override { + return TlNodeType::kFunctionType; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + TlType* parameter_type(); + const TlType* parameter_type() const; + void set_parameter_type(std::unique_ptr parameter_type); + + TlType* body_type(); + const TlType* body_type() const; + void set_body_type(std::unique_ptr body_type); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetParameterType(const nlohmann::json& json); + static absl::StatusOr> GetBodyType(const nlohmann::json& json); + + private: + std::unique_ptr parameter_type_; + std::unique_ptr body_type_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_TYPED_LAMBDA_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/typed_lambda/ast_def.textproto b/maldoca/astgen/test/typed_lambda/ast_def.textproto new file mode 100644 index 0000000..0f510b8 --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/ast_def.textproto @@ -0,0 +1,137 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "tl" + +# interface Node { +# type: string +# } +nodes { + name: "Node" +} + +# e ::= x +# | c +# | (x: T) => e +# | e e +# +# interface Expression <: Node {} +nodes { + name: "Expression" + parents: "Node" +} + +# T ::= L +# | T -> T +# +# interface Type <: Node {} +nodes { + name: "Type" + parents: 
"Node" +} + +# interface Literal <: Expression { +# type: "Literal" +# value: boolean | int64 | number | string +# } +nodes { + name: "Literal" + type: "Literal" + parents: "Expression" + fields { + name: "value" + type { + variant { + types { bool {} } + types { int64 {} } + types { double {} } + types { string {} } + } + } + } +} + +# interface Variable <: Expression { +# type: "Variable" +# identifier: string +# } +nodes { + name: "Variable" + type: "Variable" + parents: "Expression" + fields { + name: "identifier" + type { string {} } + } +} + +# interface FunctionDefinition <: Expression { +# type: "FunctionDefinition" +# parameter: Variable +# parameterType: Type +# body: Expression +# } +nodes { + name: "FunctionDefinition" + type: "FunctionDefinition" + parents: "Expression" + fields { + name: "parameter" + type { class: "Variable" } + } + fields { + name: "parameterType" + type { class: "Type" } + } + fields { + name: "body" + type { class: "Expression" } + } +} + +# interface FunctionCall <: Expression { +# type: "FunctionCall" +# caller: Expression +# callee: Expression +# } +nodes { + name: "FunctionCall" + type: "FunctionCall" + parents: "Expression" + fields { + name: "caller" + type { class: "Expression" } + } + fields { + name: "callee" + type { class: "Expression" } + } +} + +# interface LiteralType <: Type { +# type: "LiteralType" +# } +nodes { + name: "LiteralType" + type: "LiteralType" + parents: "Type" +} + +# interface FunctionType <: Type { +# type: "FunctionType" +# parameterType: Type +# bodyType: Type +# } +nodes { + name: "FunctionType" + type: "FunctionType" + parents: "Type" + fields { + name: "parameterType" + type { class: "Type" } + } + fields { + name: "bodyType" + type { class: "Type" } + } +} diff --git a/maldoca/astgen/test/typed_lambda/ast_from_json.generated.cc b/maldoca/astgen/test/typed_lambda/ast_from_json.generated.cc new file mode 100644 index 0000000..c0f827f --- /dev/null +++ 
b/maldoca/astgen/test/typed_lambda/ast_from_json.generated.cc @@ -0,0 +1,383 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/typed_lambda/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +static absl::StatusOr GetType(const nlohmann::json& json) { + auto type_it = json.find("type"); + if (type_it == json.end()) { + return absl::InvalidArgumentError("`type` is undefined."); + } + const nlohmann::json& json_type = type_it.value(); + if (json_type.is_null()) { + return absl::InvalidArgumentError("json_type is null."); + } + if (!json_type.is_string()) { + return absl::InvalidArgumentError("`json_type` expected to be string."); + } + return json_type.get(); +} + +// 
============================================================================= +// TlNode +// ============================================================================= + +absl::StatusOr> +TlNode::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "Literal") { + return TlLiteral::FromJson(json); + } else if (type == "Variable") { + return TlVariable::FromJson(json); + } else if (type == "FunctionDefinition") { + return TlFunctionDefinition::FromJson(json); + } else if (type == "FunctionCall") { + return TlFunctionCall::FromJson(json); + } else if (type == "Expression") { + return TlExpression::FromJson(json); + } else if (type == "LiteralType") { + return TlLiteralType::FromJson(json); + } else if (type == "FunctionType") { + return TlFunctionType::FromJson(json); + } else if (type == "Type") { + return TlType::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// TlExpression +// ============================================================================= + +absl::StatusOr> +TlExpression::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "Literal") { + return TlLiteral::FromJson(json); + } else if (type == "Variable") { + return TlVariable::FromJson(json); + } else if (type == "FunctionDefinition") { + return TlFunctionDefinition::FromJson(json); + } else if (type == "FunctionCall") { + return TlFunctionCall::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// TlType +// 
=============================================================================
+
+absl::StatusOr<std::unique_ptr<TlType>>
+TlType::FromJson(const nlohmann::json& json) {
+  if (!json.is_object()) {
+    return absl::InvalidArgumentError("JSON is not an object.");
+  }
+
+  MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json));
+
+  if (type == "LiteralType") {
+    return TlLiteralType::FromJson(json);
+  } else if (type == "FunctionType") {
+    return TlFunctionType::FromJson(json);
+  }
+  return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type));
+}
+
+// =============================================================================
+// TlLiteral
+// =============================================================================
+
+absl::StatusOr<std::variant<bool, int64_t, double, std::string>>
+TlLiteral::GetValue(const nlohmann::json& json) {
+  auto value_it = json.find("value");
+  if (value_it == json.end()) {
+    return absl::InvalidArgumentError("`value` is undefined.");
+  }
+  const nlohmann::json& json_value = value_it.value();
+
+  if (json_value.is_null()) {
+    return absl::InvalidArgumentError("json_value is null.");
+  }
+  if (json_value.is_boolean()) {
+    return json_value.get<bool>();
+  } else if (json_value.is_number_integer()) {
+    return json_value.get<int64_t>();
+  } else if (json_value.is_number()) {
+    return json_value.get<double>();
+  } else if (json_value.is_string()) {
+    return json_value.get<std::string>();
+  } else {
+    auto result = absl::InvalidArgumentError("json_value has invalid type.");
+    result.SetPayload("json", absl::Cord{json.dump()});
+    result.SetPayload("json_element", absl::Cord{json_value.dump()});
+    return result;
+  }
+}
+
+absl::StatusOr<std::unique_ptr<TlLiteral>>
+TlLiteral::FromJson(const nlohmann::json& json) {
+  if (!json.is_object()) {
+    return absl::InvalidArgumentError("JSON is not an object.");
+  }
+
+  MALDOCA_ASSIGN_OR_RETURN(auto value, TlLiteral::GetValue(json));
+
+  return absl::make_unique<TlLiteral>(
+      std::move(value));
+}
+
+// =============================================================================
+// TlVariable
+// 
============================================================================= + +absl::StatusOr +TlVariable::GetIdentifier(const nlohmann::json& json) { + auto identifier_it = json.find("identifier"); + if (identifier_it == json.end()) { + return absl::InvalidArgumentError("`identifier` is undefined."); + } + const nlohmann::json& json_identifier = identifier_it.value(); + + if (json_identifier.is_null()) { + return absl::InvalidArgumentError("json_identifier is null."); + } + if (!json_identifier.is_string()) { + return absl::InvalidArgumentError("Expecting json_identifier.is_string()."); + } + return json_identifier.get(); +} + +absl::StatusOr> +TlVariable::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto identifier, TlVariable::GetIdentifier(json)); + + return absl::make_unique( + std::move(identifier)); +} + +// ============================================================================= +// TlFunctionDefinition +// ============================================================================= + +absl::StatusOr> +TlFunctionDefinition::GetParameter(const nlohmann::json& json) { + auto parameter_it = json.find("parameter"); + if (parameter_it == json.end()) { + return absl::InvalidArgumentError("`parameter` is undefined."); + } + const nlohmann::json& json_parameter = parameter_it.value(); + + if (json_parameter.is_null()) { + return absl::InvalidArgumentError("json_parameter is null."); + } + return TlVariable::FromJson(json_parameter); +} + +absl::StatusOr> +TlFunctionDefinition::GetParameterType(const nlohmann::json& json) { + auto parameter_type_it = json.find("parameterType"); + if (parameter_type_it == json.end()) { + return absl::InvalidArgumentError("`parameterType` is undefined."); + } + const nlohmann::json& json_parameter_type = parameter_type_it.value(); + + if (json_parameter_type.is_null()) { + return 
absl::InvalidArgumentError("json_parameter_type is null."); + } + return TlType::FromJson(json_parameter_type); +} + +absl::StatusOr> +TlFunctionDefinition::GetBody(const nlohmann::json& json) { + auto body_it = json.find("body"); + if (body_it == json.end()) { + return absl::InvalidArgumentError("`body` is undefined."); + } + const nlohmann::json& json_body = body_it.value(); + + if (json_body.is_null()) { + return absl::InvalidArgumentError("json_body is null."); + } + return TlExpression::FromJson(json_body); +} + +absl::StatusOr> +TlFunctionDefinition::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto parameter, TlFunctionDefinition::GetParameter(json)); + MALDOCA_ASSIGN_OR_RETURN(auto parameter_type, TlFunctionDefinition::GetParameterType(json)); + MALDOCA_ASSIGN_OR_RETURN(auto body, TlFunctionDefinition::GetBody(json)); + + return absl::make_unique( + std::move(parameter), + std::move(parameter_type), + std::move(body)); +} + +// ============================================================================= +// TlFunctionCall +// ============================================================================= + +absl::StatusOr> +TlFunctionCall::GetCaller(const nlohmann::json& json) { + auto caller_it = json.find("caller"); + if (caller_it == json.end()) { + return absl::InvalidArgumentError("`caller` is undefined."); + } + const nlohmann::json& json_caller = caller_it.value(); + + if (json_caller.is_null()) { + return absl::InvalidArgumentError("json_caller is null."); + } + return TlExpression::FromJson(json_caller); +} + +absl::StatusOr> +TlFunctionCall::GetCallee(const nlohmann::json& json) { + auto callee_it = json.find("callee"); + if (callee_it == json.end()) { + return absl::InvalidArgumentError("`callee` is undefined."); + } + const nlohmann::json& json_callee = callee_it.value(); + + if (json_callee.is_null()) { + return 
absl::InvalidArgumentError("json_callee is null."); + } + return TlExpression::FromJson(json_callee); +} + +absl::StatusOr> +TlFunctionCall::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto caller, TlFunctionCall::GetCaller(json)); + MALDOCA_ASSIGN_OR_RETURN(auto callee, TlFunctionCall::GetCallee(json)); + + return absl::make_unique( + std::move(caller), + std::move(callee)); +} + +// ============================================================================= +// TlLiteralType +// ============================================================================= + +absl::StatusOr> +TlLiteralType::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + + return absl::make_unique( + ); +} + +// ============================================================================= +// TlFunctionType +// ============================================================================= + +absl::StatusOr> +TlFunctionType::GetParameterType(const nlohmann::json& json) { + auto parameter_type_it = json.find("parameterType"); + if (parameter_type_it == json.end()) { + return absl::InvalidArgumentError("`parameterType` is undefined."); + } + const nlohmann::json& json_parameter_type = parameter_type_it.value(); + + if (json_parameter_type.is_null()) { + return absl::InvalidArgumentError("json_parameter_type is null."); + } + return TlType::FromJson(json_parameter_type); +} + +absl::StatusOr> +TlFunctionType::GetBodyType(const nlohmann::json& json) { + auto body_type_it = json.find("bodyType"); + if (body_type_it == json.end()) { + return absl::InvalidArgumentError("`bodyType` is undefined."); + } + const nlohmann::json& json_body_type = body_type_it.value(); + + if (json_body_type.is_null()) { + return absl::InvalidArgumentError("json_body_type is null."); + } + return 
TlType::FromJson(json_body_type); +} + +absl::StatusOr> +TlFunctionType::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto parameter_type, TlFunctionType::GetParameterType(json)); + MALDOCA_ASSIGN_OR_RETURN(auto body_type, TlFunctionType::GetBodyType(json)); + + return absl::make_unique( + std::move(parameter_type), + std::move(body_type)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/typed_lambda/ast_gen_test.cc b/maldoca/astgen/test/typed_lambda/ast_gen_test.cc new file mode 100644 index 0000000..573dde2 --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/ast_gen_test.cc @@ -0,0 +1,45 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + TypedLambda, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = "maldoca/astgen/test/" + "typed_lambda/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "typed_lambda/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/typed_lambda", + .expected_ast_header_path = + "maldoca/astgen/test/" + "typed_lambda/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/" + "typed_lambda/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "typed_lambda/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "typed_lambda/ast_from_json.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/typed_lambda/ast_test.cc b/maldoca/astgen/test/typed_lambda/ast_test.cc new file mode 100644 index 0000000..42df64a --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/ast_test.cc @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "nlohmann/json.hpp" +#include "maldoca/astgen/test/typed_lambda/ast.generated.h" +#include "maldoca/base/testing/status_matchers.h" + +namespace maldoca { +namespace astgen { +namespace { + +// Even though JSON only has one "number" type, nlohmann::json tries to infer +// the most appropriate type. Here it detects that "1" is an integer. +TEST(TypedLambdaAstTest, LiteralInteger) { + constexpr char kJsonString[] = R"( + { + "type": "Literal", + "value": 1 + } + )"; + + nlohmann::json json = nlohmann::json::parse(kJsonString, /*cb=*/nullptr, + /*allow_exceptions=*/false); + + MALDOCA_ASSERT_OK_AND_ASSIGN(auto expr, TlExpression::FromJson(json)); + + auto *literal = dynamic_cast(expr.get()); + ASSERT_NE(literal, nullptr); + + EXPECT_TRUE(std::holds_alternative(literal->value())); +} + +// Even though JSON only has one "number" type, nlohmann::json tries to infer +// the most appropriate type. Here it detects that "1.0" is a double. +TEST(TypedLambdaAstTest, LiteralDouble) { + constexpr char kJsonString[] = R"( + { + "type": "Literal", + "value": 1.0 + } + )"; + + nlohmann::json json = nlohmann::json::parse(kJsonString, /*cb=*/nullptr, + /*allow_exceptions=*/false); + + MALDOCA_ASSERT_OK_AND_ASSIGN(auto expr, TlExpression::FromJson(json)); + + auto *literal = dynamic_cast(expr.get()); + ASSERT_NE(literal, nullptr); + + EXPECT_TRUE(std::holds_alternative(literal->value())); +} + +} // namespace +} // namespace astgen +} // namespace maldoca diff --git a/maldoca/astgen/test/typed_lambda/ast_to_json.generated.cc b/maldoca/astgen/test/typed_lambda/ast_to_json.generated.cc new file mode 100644 index 0000000..784e758 --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/ast_to_json.generated.cc @@ -0,0 +1,235 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/typed_lambda/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// TlNode +// ============================================================================= + +void TlNode::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +// ============================================================================= +// TlExpression +// ============================================================================= + +void TlExpression::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +// ============================================================================= +// TlType +// ============================================================================= + +void TlType::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + 
+// ============================================================================= +// TlLiteral +// ============================================================================= + +void TlLiteral::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + switch (value_.index()) { + case 0: { + os << "\"value\":" << (nlohmann::json(std::get<0>(value_))).dump(); + break; + } + case 1: { + os << "\"value\":" << (nlohmann::json(std::get<1>(value_))).dump(); + break; + } + case 2: { + os << "\"value\":" << (nlohmann::json(std::get<2>(value_))).dump(); + break; + } + case 3: { + os << "\"value\":" << (nlohmann::json(std::get<3>(value_))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } +} + +void TlLiteral::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"Literal\""; + TlNode::SerializeFields(os, needs_comma); + TlExpression::SerializeFields(os, needs_comma); + TlLiteral::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// TlVariable +// ============================================================================= + +void TlVariable::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"identifier\":" << (nlohmann::json(identifier_)).dump(); +} + +void TlVariable::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"Variable\""; + TlNode::SerializeFields(os, needs_comma); + TlExpression::SerializeFields(os, needs_comma); + TlVariable::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// TlFunctionDefinition +// ============================================================================= + +void 
TlFunctionDefinition::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"parameter\":"; + parameter_->Serialize(os); + MaybeAddComma(os, needs_comma); + os << "\"parameterType\":"; + parameter_type_->Serialize(os); + MaybeAddComma(os, needs_comma); + os << "\"body\":"; + body_->Serialize(os); +} + +void TlFunctionDefinition::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"FunctionDefinition\""; + TlNode::SerializeFields(os, needs_comma); + TlExpression::SerializeFields(os, needs_comma); + TlFunctionDefinition::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// TlFunctionCall +// ============================================================================= + +void TlFunctionCall::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"caller\":"; + caller_->Serialize(os); + MaybeAddComma(os, needs_comma); + os << "\"callee\":"; + callee_->Serialize(os); +} + +void TlFunctionCall::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"FunctionCall\""; + TlNode::SerializeFields(os, needs_comma); + TlExpression::SerializeFields(os, needs_comma); + TlFunctionCall::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// TlLiteralType +// ============================================================================= + +void TlLiteralType::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +void TlLiteralType::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"LiteralType\""; + TlNode::SerializeFields(os, needs_comma); + 
TlType::SerializeFields(os, needs_comma); + TlLiteralType::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// TlFunctionType +// ============================================================================= + +void TlFunctionType::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"parameterType\":"; + parameter_type_->Serialize(os); + MaybeAddComma(os, needs_comma); + os << "\"bodyType\":"; + body_type_->Serialize(os); +} + +void TlFunctionType::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"FunctionType\""; + TlNode::SerializeFields(os, needs_comma); + TlType::SerializeFields(os, needs_comma); + TlFunctionType::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/typed_lambda/ast_ts_interface.generated b/maldoca/astgen/test/typed_lambda/ast_ts_interface.generated new file mode 100644 index 0000000..0660f52 --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/ast_ts_interface.generated @@ -0,0 +1,35 @@ +interface Node { +} + +interface Expression <: Node { +} + +interface Type <: Node { +} + +interface Literal <: Expression { + value: boolean | /*int64*/number | /*double*/number | string +} + +interface Variable <: Expression { + identifier: string +} + +interface FunctionDefinition <: Expression { + parameter: Variable + parameterType: Type + body: Expression +} + +interface FunctionCall <: Expression { + caller: Expression + callee: Expression +} + +interface LiteralType <: Type { +} + +interface FunctionType <: Type { + parameterType: Type + bodyType: Type +} diff --git a/maldoca/astgen/test/typed_lambda/tlir_ops.generated.td b/maldoca/astgen/test/typed_lambda/tlir_ops.generated.td new 
file mode 100644 index 0000000..2c12356 --- /dev/null +++ b/maldoca/astgen/test/typed_lambda/tlir_ops.generated.td @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_TYPED_LAMBDA_TLIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_TYPED_LAMBDA_TLIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/typed_lambda/interfaces.td" +include "maldoca/astgen/test/typed_lambda/tlir_dialect.td" +include "maldoca/astgen/test/typed_lambda/tlir_types.td" + +#endif // MALDOCA_ASTGEN_TEST_TYPED_LAMBDA_TLIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/union/BUILD b/maldoca/astgen/test/union/BUILD new file mode 100644 index 0000000..5121f35 --- /dev/null +++ b/maldoca/astgen/test/union/BUILD @@ -0,0 +1,65 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//maldoca/astgen:__subpackages__", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = [ + "ast.generated.h", + ], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) diff --git a/maldoca/astgen/test/union/ast.generated.cc b/maldoca/astgen/test/union/ast.generated.cc new file mode 100644 index 0000000..e3eb1fd --- /dev/null +++ b/maldoca/astgen/test/union/ast.generated.cc @@ -0,0 +1,139 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#include "maldoca/astgen/test/union/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// EUnionType +// ============================================================================= + +absl::string_view EUnionTypeTypeToString(EUnionTypeType union_type_type) { + switch (union_type_type) { + case EUnionTypeType::kSubNodeA: + return "SubNodeA"; + case EUnionTypeType::kSubNodeB: + return "SubNodeB"; + } +} + +absl::StatusOr StringToEUnionTypeType(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"SubNodeA", EUnionTypeType::kSubNodeA}, + {"SubNodeB", EUnionTypeType::kSubNodeB}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for EUnionTypeType: ", s)); + } + return 
it->second; +} + +// ============================================================================= +// ENode +// ============================================================================= + +ENode::ENode( + std::string name, + std::unique_ptr content) + : name_(std::move(name)), + content_(std::move(content)) {} + +absl::string_view ENode::name() const { + return name_; +} + +void ENode::set_name(std::string name) { + name_ = std::move(name); +} + +EUnionType* ENode::content() { + return content_.get(); +} + +const EUnionType* ENode::content() const { + return content_.get(); +} + +void ENode::set_content(std::unique_ptr content) { + content_ = std::move(content); +} + +// ============================================================================= +// ESubNodeA +// ============================================================================= + +ESubNodeA::ESubNodeA( + std::string value_a) + : EUnionType(), + value_a_(std::move(value_a)) {} + +absl::string_view ESubNodeA::value_a() const { + return value_a_; +} + +void ESubNodeA::set_value_a(std::string value_a) { + value_a_ = std::move(value_a); +} + +// ============================================================================= +// ESubNodeB +// ============================================================================= + +ESubNodeB::ESubNodeB( + std::string value_b) + : EUnionType(), + value_b_(std::move(value_b)) {} + +absl::string_view ESubNodeB::value_b() const { + return value_b_; +} + +void ESubNodeB::set_value_b(std::string value_b) { + value_b_ = std::move(value_b); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/union/ast.generated.h b/maldoca/astgen/test/union/ast.generated.h new file mode 100644 index 0000000..1997780 --- /dev/null +++ b/maldoca/astgen/test/union/ast.generated.h @@ -0,0 +1,161 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); 
+// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_UNION_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_UNION_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +enum class EUnionTypeType { + kSubNodeA, + kSubNodeB, +}; + +absl::string_view EUnionTypeTypeToString(EUnionTypeType union_type_type); +absl::StatusOr StringToEUnionTypeType(absl::string_view s); + +class EUnionType { + public: + virtual ~EUnionType() = default; + + virtual EUnionTypeType union_type_type() const = 0; + + virtual void Serialize(std::ostream& os) const = 0; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. 
+ void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class ENode { + public: + explicit ENode( + std::string name, + std::unique_ptr content); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + absl::string_view name() const; + void set_name(std::string name); + + EUnionType* content(); + const EUnionType* content() const; + void set_content(std::unique_ptr content); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr GetName(const nlohmann::json& json); + static absl::StatusOr> GetContent(const nlohmann::json& json); + + private: + std::string name_; + std::unique_ptr content_; +}; + +class ESubNodeA : public virtual EUnionType { + public: + explicit ESubNodeA( + std::string value_a); + + EUnionTypeType union_type_type() const override { + return EUnionTypeType::kSubNodeA; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + absl::string_view value_a() const; + void set_value_a(std::string value_a); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. 
+ static absl::StatusOr GetValueA(const nlohmann::json& json); + + private: + std::string value_a_; +}; + +class ESubNodeB : public virtual EUnionType { + public: + explicit ESubNodeB( + std::string value_b); + + EUnionTypeType union_type_type() const override { + return EUnionTypeType::kSubNodeB; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + absl::string_view value_b() const; + void set_value_b(std::string value_b); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr GetValueB(const nlohmann::json& json); + + private: + std::string value_b_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_UNION_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/union/ast_def.textproto b/maldoca/astgen/test/union/ast_def.textproto new file mode 100644 index 0000000..c7f9602 --- /dev/null +++ b/maldoca/astgen/test/union/ast_def.textproto @@ -0,0 +1,44 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "e" + +nodes { + name: "Node" + fields { + name: "name" + type { string {} } + kind: FIELD_KIND_ATTR + } + fields { + name: "content" + type { class: "UnionType" } + kind: FIELD_KIND_ATTR + } +} + +nodes { + name: "SubNodeA" + type: "SubNodeA" + fields { + name: "valueA" + type { string {} } + kind: FIELD_KIND_ATTR + } +} + +nodes { + name: "SubNodeB" + type: "SubNodeB" + fields { + name: "valueB" + type { string {} } + kind: FIELD_KIND_ATTR + } +} + +union_types { + name: "UnionType" + types: "SubNodeA" + types: "SubNodeB" +} diff --git a/maldoca/astgen/test/union/ast_from_json.generated.cc 
b/maldoca/astgen/test/union/ast_from_json.generated.cc new file mode 100644 index 0000000..6b70915 --- /dev/null +++ b/maldoca/astgen/test/union/ast_from_json.generated.cc @@ -0,0 +1,197 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/union/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +static absl::StatusOr GetType(const nlohmann::json& json) { + auto type_it = json.find("type"); + if (type_it == json.end()) { + return absl::InvalidArgumentError("`type` is undefined."); + } + const nlohmann::json& json_type = type_it.value(); + if (json_type.is_null()) { + return absl::InvalidArgumentError("json_type is null."); + } + if (!json_type.is_string()) { + return absl::InvalidArgumentError("`json_type` expected to be string."); + } 
+ return json_type.get(); +} + +// ============================================================================= +// EUnionType +// ============================================================================= + +absl::StatusOr> +EUnionType::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "SubNodeA") { + return ESubNodeA::FromJson(json); + } else if (type == "SubNodeB") { + return ESubNodeB::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// ENode +// ============================================================================= + +absl::StatusOr +ENode::GetName(const nlohmann::json& json) { + auto name_it = json.find("name"); + if (name_it == json.end()) { + return absl::InvalidArgumentError("`name` is undefined."); + } + const nlohmann::json& json_name = name_it.value(); + + if (json_name.is_null()) { + return absl::InvalidArgumentError("json_name is null."); + } + if (!json_name.is_string()) { + return absl::InvalidArgumentError("Expecting json_name.is_string()."); + } + return json_name.get(); +} + +absl::StatusOr> +ENode::GetContent(const nlohmann::json& json) { + auto content_it = json.find("content"); + if (content_it == json.end()) { + return absl::InvalidArgumentError("`content` is undefined."); + } + const nlohmann::json& json_content = content_it.value(); + + if (json_content.is_null()) { + return absl::InvalidArgumentError("json_content is null."); + } + return EUnionType::FromJson(json_content); +} + +absl::StatusOr> +ENode::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto name, ENode::GetName(json)); + MALDOCA_ASSIGN_OR_RETURN(auto content, 
ENode::GetContent(json)); + + return absl::make_unique( + std::move(name), + std::move(content)); +} + +// ============================================================================= +// ESubNodeA +// ============================================================================= + +absl::StatusOr +ESubNodeA::GetValueA(const nlohmann::json& json) { + auto value_a_it = json.find("valueA"); + if (value_a_it == json.end()) { + return absl::InvalidArgumentError("`valueA` is undefined."); + } + const nlohmann::json& json_value_a = value_a_it.value(); + + if (json_value_a.is_null()) { + return absl::InvalidArgumentError("json_value_a is null."); + } + if (!json_value_a.is_string()) { + return absl::InvalidArgumentError("Expecting json_value_a.is_string()."); + } + return json_value_a.get(); +} + +absl::StatusOr> +ESubNodeA::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto value_a, ESubNodeA::GetValueA(json)); + + return absl::make_unique( + std::move(value_a)); +} + +// ============================================================================= +// ESubNodeB +// ============================================================================= + +absl::StatusOr +ESubNodeB::GetValueB(const nlohmann::json& json) { + auto value_b_it = json.find("valueB"); + if (value_b_it == json.end()) { + return absl::InvalidArgumentError("`valueB` is undefined."); + } + const nlohmann::json& json_value_b = value_b_it.value(); + + if (json_value_b.is_null()) { + return absl::InvalidArgumentError("json_value_b is null."); + } + if (!json_value_b.is_string()) { + return absl::InvalidArgumentError("Expecting json_value_b.is_string()."); + } + return json_value_b.get(); +} + +absl::StatusOr> +ESubNodeB::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto value_b, 
ESubNodeB::GetValueB(json)); + + return absl::make_unique( + std::move(value_b)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/union/ast_gen_test.cc b/maldoca/astgen/test/union/ast_gen_test.cc new file mode 100644 index 0000000..0a31bbc --- /dev/null +++ b/maldoca/astgen/test/union/ast_gen_test.cc @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + Lambda, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = + "maldoca/astgen/test/union/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "union/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/union", + .ir_path = "maldoca/astgen/test/union", + .expected_ast_header_path = + "maldoca/astgen/test/union/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/union/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "union/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "union/ast_from_json.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/union/ast_to_json.generated.cc b/maldoca/astgen/test/union/ast_to_json.generated.cc new file mode 100644 index 0000000..28f8437 --- /dev/null +++ b/maldoca/astgen/test/union/ast_to_json.generated.cc @@ -0,0 +1,120 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/union/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// EUnionType +// ============================================================================= + +void EUnionType::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +// ============================================================================= +// ENode +// ============================================================================= + +void ENode::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"name\":" << (nlohmann::json(name_)).dump(); + MaybeAddComma(os, needs_comma); + os << "\"content\":"; + content_->Serialize(os); +} + +void ENode::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + ENode::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// ESubNodeA +// ============================================================================= + +void ESubNodeA::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"valueA\":" << (nlohmann::json(value_a_)).dump(); +} + +void ESubNodeA::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << 
"\"type\":\"SubNodeA\""; + EUnionType::SerializeFields(os, needs_comma); + ESubNodeA::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// ESubNodeB +// ============================================================================= + +void ESubNodeB::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + os << "\"valueB\":" << (nlohmann::json(value_b_)).dump(); +} + +void ESubNodeB::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"SubNodeB\""; + EUnionType::SerializeFields(os, needs_comma); + ESubNodeB::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/union/ast_ts_interface.generated b/maldoca/astgen/test/union/ast_ts_interface.generated new file mode 100644 index 0000000..129470a --- /dev/null +++ b/maldoca/astgen/test/union/ast_ts_interface.generated @@ -0,0 +1,15 @@ +interface Node { + name: string + content: UnionType +} + +interface SubNodeA <: UnionType { + valueA: string +} + +interface SubNodeB <: UnionType { + valueB: string +} + +interface UnionType { +} diff --git a/maldoca/astgen/test/union/eir_ops.generated.td b/maldoca/astgen/test/union/eir_ops.generated.td new file mode 100644 index 0000000..7687265 --- /dev/null +++ b/maldoca/astgen/test/union/eir_ops.generated.td @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_UNION_EIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_UNION_EIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/union/interfaces.td" +include "maldoca/astgen/test/union/eir_dialect.td" +include "maldoca/astgen/test/union/eir_types.td" + +#endif // MALDOCA_ASTGEN_TEST_UNION_EIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/variant/BUILD b/maldoca/astgen/test/variant/BUILD new file mode 100644 index 0000000..28ecad4 --- /dev/null +++ b/maldoca/astgen/test/variant/BUILD @@ -0,0 +1,208 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +licenses(["notice"]) + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//maldoca/astgen:__subpackages__", + ], +) + +cc_test( + name = "ast_gen_test", + srcs = ["ast_gen_test.cc"], + data = [ + "ast.generated.cc", + "ast.generated.h", + "ast_def.textproto", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + "ast_ts_interface.generated", + "vir_ops.generated.td", + "//maldoca/astgen/test/variant/conversion:ast_to_vir.generated.cc", + "//maldoca/astgen/test/variant/conversion:vir_to_ast.generated.cc", + ], + deps = [ + "//maldoca/astgen/test:ast_gen_test_util", + "@googletest//:gtest_main", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.generated.cc", + "ast_from_json.generated.cc", + "ast_to_json.generated.cc", + ], + hdrs = ["ast.generated.h"], + deps = [ + "//maldoca/base:status", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@nlohmann_json//:json", + ], +) + +td_library( + name = "interfaces_td_files", + srcs = [ + "interfaces.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "interfaces_inc_gen", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "interfaces.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "interfaces.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "interfaces.td", + deps = [":interfaces_td_files"], +) + +td_library( + name = "vir_dialect_td_files", + srcs = [ + "vir_dialect.td", + ], + deps = [ + 
"@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "vir_dialect_inc_gen", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + "-dialect=vir", + ], + "vir_dialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=vir", + ], + "vir_dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vir_dialect.td", + deps = [":vir_dialect_td_files"], +) + +td_library( + name = "vir_types_td_files", + srcs = [ + "vir_types.td", + ], + deps = [ + ":vir_dialect_td_files", + ], +) + +gentbl_cc_library( + name = "vir_types_inc_gen", + tbl_outs = [ + ( + ["-gen-typedef-decls"], + "vir_types.h.inc", + ), + ( + ["-gen-typedef-defs"], + "vir_types.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vir_types.td", + deps = [":vir_types_td_files"], +) + +td_library( + name = "vir_ops_generated_td_files", + srcs = [ + "vir_ops.generated.td", + ], + deps = [ + ":interfaces_td_files", + ":vir_dialect_td_files", + ":vir_types_td_files", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "vir_ops_generated_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "vir_ops.generated.h.inc", + ), + ( + ["-gen-op-defs"], + "vir_ops.generated.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vir_ops.generated.td", + deps = [":vir_ops_generated_td_files"], +) + +cc_library( + name = "ir", + srcs = ["ir.cc"], + hdrs = ["ir.h"], + deps = [ + ":interfaces_inc_gen", + ":vir_dialect_inc_gen", + ":vir_ops_generated_inc_gen", + ":vir_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Parser", + 
"@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + ], +) diff --git a/maldoca/astgen/test/variant/ast.generated.cc b/maldoca/astgen/test/variant/ast.generated.cc new file mode 100644 index 0000000..f6e69f4 --- /dev/null +++ b/maldoca/astgen/test/variant/ast.generated.cc @@ -0,0 +1,266 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. 
+// ============================================================================= + +#include "maldoca/astgen/test/variant/ast.generated.h" + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +// ============================================================================= +// VBaseClass +// ============================================================================= + +absl::string_view VBaseClassTypeToString(VBaseClassType base_class_type) { + switch (base_class_type) { + case VBaseClassType::kDerivedClass1: + return "DerivedClass1"; + case VBaseClassType::kDerivedClass2: + return "DerivedClass2"; + } +} + +absl::StatusOr StringToVBaseClassType(absl::string_view s) { + static const auto *kMap = new absl::flat_hash_map { + {"DerivedClass1", VBaseClassType::kDerivedClass1}, + {"DerivedClass2", VBaseClassType::kDerivedClass2}, + }; + + auto it = kMap->find(s); + if (it == kMap->end()) { + return absl::InvalidArgumentError(absl::StrCat("Invalid string for VBaseClassType: ", s)); + } + return it->second; +} + +// ============================================================================= +// VDerivedClass1 +// ============================================================================= + +// ============================================================================= +// VDerivedClass2 +// ============================================================================= + +// ============================================================================= +// VNode +// 
============================================================================= + +VNode::VNode( + std::variant simple_variant_builtin, + std::optional> nullable_variant_builtin, + std::optional> optional_variant_builtin, + std::variant, std::unique_ptr> simple_variant_class, + std::optional, std::unique_ptr>> nullable_variant_class, + std::optional, std::unique_ptr>> optional_variant_class) + : simple_variant_builtin_(std::move(simple_variant_builtin)), + nullable_variant_builtin_(std::move(nullable_variant_builtin)), + optional_variant_builtin_(std::move(optional_variant_builtin)), + simple_variant_class_(std::move(simple_variant_class)), + nullable_variant_class_(std::move(nullable_variant_class)), + optional_variant_class_(std::move(optional_variant_class)) {} + +std::variant VNode::simple_variant_builtin() const { + switch (simple_variant_builtin_.index()) { + case 0: { + return std::get<0>(simple_variant_builtin_); + } + case 1: { + return std::get<1>(simple_variant_builtin_); + } + default: + LOG(FATAL) << "Unreachable code."; + } +} + +void VNode::set_simple_variant_builtin(std::variant simple_variant_builtin) { + simple_variant_builtin_ = std::move(simple_variant_builtin); +} + +std::optional> VNode::nullable_variant_builtin() const { + if (!nullable_variant_builtin_.has_value()) { + return std::nullopt; + } else { + switch (nullable_variant_builtin_.value().index()) { + case 0: { + return std::get<0>(nullable_variant_builtin_.value()); + } + case 1: { + return std::get<1>(nullable_variant_builtin_.value()); + } + default: + LOG(FATAL) << "Unreachable code."; + } + } +} + +void VNode::set_nullable_variant_builtin(std::optional> nullable_variant_builtin) { + nullable_variant_builtin_ = std::move(nullable_variant_builtin); +} + +std::optional> VNode::optional_variant_builtin() const { + if (!optional_variant_builtin_.has_value()) { + return std::nullopt; + } else { + switch (optional_variant_builtin_.value().index()) { + case 0: { + return 
std::get<0>(optional_variant_builtin_.value()); + } + case 1: { + return std::get<1>(optional_variant_builtin_.value()); + } + default: + LOG(FATAL) << "Unreachable code."; + } + } +} + +void VNode::set_optional_variant_builtin(std::optional> optional_variant_builtin) { + optional_variant_builtin_ = std::move(optional_variant_builtin); +} + +std::variant VNode::simple_variant_class() { + switch (simple_variant_class_.index()) { + case 0: { + return std::get<0>(simple_variant_class_).get(); + } + case 1: { + return std::get<1>(simple_variant_class_).get(); + } + default: + LOG(FATAL) << "Unreachable code."; + } +} + +std::variant VNode::simple_variant_class() const { + switch (simple_variant_class_.index()) { + case 0: { + return std::get<0>(simple_variant_class_).get(); + } + case 1: { + return std::get<1>(simple_variant_class_).get(); + } + default: + LOG(FATAL) << "Unreachable code."; + } +} + +void VNode::set_simple_variant_class(std::variant, std::unique_ptr> simple_variant_class) { + simple_variant_class_ = std::move(simple_variant_class); +} + +std::optional> VNode::nullable_variant_class() { + if (!nullable_variant_class_.has_value()) { + return std::nullopt; + } else { + switch (nullable_variant_class_.value().index()) { + case 0: { + return std::get<0>(nullable_variant_class_.value()).get(); + } + case 1: { + return std::get<1>(nullable_variant_class_.value()).get(); + } + default: + LOG(FATAL) << "Unreachable code."; + } + } +} + +std::optional> VNode::nullable_variant_class() const { + if (!nullable_variant_class_.has_value()) { + return std::nullopt; + } else { + switch (nullable_variant_class_.value().index()) { + case 0: { + return std::get<0>(nullable_variant_class_.value()).get(); + } + case 1: { + return std::get<1>(nullable_variant_class_.value()).get(); + } + default: + LOG(FATAL) << "Unreachable code."; + } + } +} + +void VNode::set_nullable_variant_class(std::optional, std::unique_ptr>> nullable_variant_class) { + nullable_variant_class_ = 
std::move(nullable_variant_class); +} + +std::optional> VNode::optional_variant_class() { + if (!optional_variant_class_.has_value()) { + return std::nullopt; + } else { + switch (optional_variant_class_.value().index()) { + case 0: { + return std::get<0>(optional_variant_class_.value()).get(); + } + case 1: { + return std::get<1>(optional_variant_class_.value()).get(); + } + default: + LOG(FATAL) << "Unreachable code."; + } + } +} + +std::optional> VNode::optional_variant_class() const { + if (!optional_variant_class_.has_value()) { + return std::nullopt; + } else { + switch (optional_variant_class_.value().index()) { + case 0: { + return std::get<0>(optional_variant_class_.value()).get(); + } + case 1: { + return std::get<1>(optional_variant_class_.value()).get(); + } + default: + LOG(FATAL) << "Unreachable code."; + } + } +} + +void VNode::set_optional_variant_class(std::optional, std::unique_ptr>> optional_variant_class) { + optional_variant_class_ = std::move(optional_variant_class); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/variant/ast.generated.h b/maldoca/astgen/test/variant/ast.generated.h new file mode 100644 index 0000000..9877dc7 --- /dev/null +++ b/maldoca/astgen/test/variant/ast.generated.h @@ -0,0 +1,161 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_VARIANT_AST_GENERATED_H_ +#define MALDOCA_ASTGEN_TEST_VARIANT_AST_GENERATED_H_ + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +enum class VBaseClassType { + kDerivedClass1, + kDerivedClass2, +}; + +absl::string_view VBaseClassTypeToString(VBaseClassType base_class_type); +absl::StatusOr StringToVBaseClassType(absl::string_view s); + +class VBaseClass { + public: + virtual ~VBaseClass() = default; + + virtual VBaseClassType base_class_type() const = 0; + + virtual void Serialize(std::ostream& os) const = 0; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class VDerivedClass1 : public virtual VBaseClass { + public: + VBaseClassType base_class_type() const override { + return VBaseClassType::kDerivedClass1; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. 
+ void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class VDerivedClass2 : public virtual VBaseClass { + public: + VBaseClassType base_class_type() const override { + return VBaseClassType::kDerivedClass2; + } + + void Serialize(std::ostream& os) const override; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + protected: + // Internal function used by Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; +}; + +class VNode { + public: + explicit VNode( + std::variant simple_variant_builtin, + std::optional> nullable_variant_builtin, + std::optional> optional_variant_builtin, + std::variant, std::unique_ptr> simple_variant_class, + std::optional, std::unique_ptr>> nullable_variant_class, + std::optional, std::unique_ptr>> optional_variant_class); + + void Serialize(std::ostream& os) const; + + static absl::StatusOr> FromJson(const nlohmann::json& json); + + std::variant simple_variant_builtin() const; + void set_simple_variant_builtin(std::variant simple_variant_builtin); + + std::optional> nullable_variant_builtin() const; + void set_nullable_variant_builtin(std::optional> nullable_variant_builtin); + + std::optional> optional_variant_builtin() const; + void set_optional_variant_builtin(std::optional> optional_variant_builtin); + + std::variant simple_variant_class(); + std::variant simple_variant_class() const; + void set_simple_variant_class(std::variant, std::unique_ptr> simple_variant_class); + + std::optional> nullable_variant_class(); + std::optional> nullable_variant_class() const; + void set_nullable_variant_class(std::optional, std::unique_ptr>> nullable_variant_class); + + std::optional> optional_variant_class(); + std::optional> optional_variant_class() const; + void set_optional_variant_class(std::optional, std::unique_ptr>> optional_variant_class); + + protected: + // Internal function used by 
Serialize(). + // Sets the fields defined in this class. + // Does not set fields defined in ancestors. + void SerializeFields(std::ostream& os, bool &needs_comma) const; + + // Internal functions used by FromJson(). + // Extracts a field from a JSON object. + static absl::StatusOr> GetSimpleVariantBuiltin(const nlohmann::json& json); + static absl::StatusOr>> GetNullableVariantBuiltin(const nlohmann::json& json); + static absl::StatusOr>> GetOptionalVariantBuiltin(const nlohmann::json& json); + static absl::StatusOr, std::unique_ptr>> GetSimpleVariantClass(const nlohmann::json& json); + static absl::StatusOr, std::unique_ptr>>> GetNullableVariantClass(const nlohmann::json& json); + static absl::StatusOr, std::unique_ptr>>> GetOptionalVariantClass(const nlohmann::json& json); + + private: + std::variant simple_variant_builtin_; + std::optional> nullable_variant_builtin_; + std::optional> optional_variant_builtin_; + std::variant, std::unique_ptr> simple_variant_class_; + std::optional, std::unique_ptr>> nullable_variant_class_; + std::optional, std::unique_ptr>> optional_variant_class_; +}; + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_VARIANT_AST_GENERATED_H_ diff --git a/maldoca/astgen/test/variant/ast_def.textproto b/maldoca/astgen/test/variant/ast_def.textproto new file mode 100644 index 0000000..40c18ec --- /dev/null +++ b/maldoca/astgen/test/variant/ast_def.textproto @@ -0,0 +1,106 @@ +# proto-file: maldoca/astgen/ast_def.proto +# proto-message: AstDefPb + +lang_name: "v" + +# interface BaseClass {} +nodes { + name: "BaseClass" + kinds: FIELD_KIND_RVAL +} + +# interface DerivedClass1 {} +nodes { + name: "DerivedClass1" + parents: "BaseClass" + type: "DerivedClass1" + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} + +# interface DerivedClass2 {} +nodes { + name: "DerivedClass2" + parents: "BaseClass" + type: "DerivedClass2" + kinds: FIELD_KIND_RVAL + 
should_generate_ir_op: true +} + +# interface Node { +# simpleVariantBuiltin: number | string +# nullableVariantBuiltin : number | string | null +# optionalVariantBuiltin? : number | string +# simpleVariantClass: DerivedClass1 | DerivedClass2 +# nullableVariantClass: DerivedClass1 | DerivedClass2 | null +# optionalVariantClass?: DerivedClass1 | DerivedClass2 +# } +nodes { + name: "Node" + fields { + name: "simpleVariantBuiltin" + type { + variant { + types { double {} } + types { string {} } + } + } + kind: FIELD_KIND_ATTR + } + fields { + name: "nullableVariantBuiltin" + optionalness: OPTIONALNESS_MAYBE_NULL + type { + variant { + types { double {} } + types { string {} } + } + } + kind: FIELD_KIND_ATTR + } + fields { + name: "optionalVariantBuiltin" + optionalness: OPTIONALNESS_MAYBE_UNDEFINED + type { + variant { + types { double {} } + types { string {} } + } + } + kind: FIELD_KIND_ATTR + } + fields { + name: "simpleVariantClass" + type { + variant { + types { class: "DerivedClass1" } + types { class: "DerivedClass2" } + } + } + kind: FIELD_KIND_RVAL + } + fields { + name: "nullableVariantClass" + optionalness: OPTIONALNESS_MAYBE_NULL + type { + variant { + types { class: "DerivedClass1" } + types { class: "DerivedClass2" } + } + } + kind: FIELD_KIND_RVAL + } + fields { + name: "optionalVariantClass" + optionalness: OPTIONALNESS_MAYBE_UNDEFINED + type { + variant { + types { class: "DerivedClass1" } + types { class: "DerivedClass2" } + } + } + kind: FIELD_KIND_RVAL + } + kinds: FIELD_KIND_RVAL + should_generate_ir_op: true +} diff --git a/maldoca/astgen/test/variant/ast_from_json.generated.cc b/maldoca/astgen/test/variant/ast_from_json.generated.cc new file mode 100644 index 0000000..d98260d --- /dev/null +++ b/maldoca/astgen/test/variant/ast_from_json.generated.cc @@ -0,0 +1,308 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// NOLINTBEGIN(whitespace/line_length) +// clang-format off +// IWYU pragma: begin_keep + +#include +#include +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/variant/ast.generated.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "maldoca/base/status_macros.h" +#include "nlohmann/json.hpp" + +namespace maldoca { + +static absl::StatusOr GetType(const nlohmann::json& json) { + auto type_it = json.find("type"); + if (type_it == json.end()) { + return absl::InvalidArgumentError("`type` is undefined."); + } + const nlohmann::json& json_type = type_it.value(); + if (json_type.is_null()) { + return absl::InvalidArgumentError("json_type is null."); + } + if (!json_type.is_string()) { + return absl::InvalidArgumentError("`json_type` expected to be string."); + } + return json_type.get(); +} + +// ============================================================================= +// VBaseClass +// ============================================================================= + +absl::StatusOr> +VBaseClass::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not 
an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(std::string type, GetType(json)); + + if (type == "DerivedClass1") { + return VDerivedClass1::FromJson(json); + } else if (type == "DerivedClass2") { + return VDerivedClass2::FromJson(json); + } + return absl::InvalidArgumentError(absl::StrCat("Invalid type: ", type)); +} + +// ============================================================================= +// VDerivedClass1 +// ============================================================================= + +static bool IsDerivedClass1(const nlohmann::json& json) { + if (!json.is_object()) { + return false; + } + auto type_it = json.find("type"); + if (type_it == json.end()) { + return false; + } + const nlohmann::json &type_json = type_it.value(); + if (!type_json.is_string()) { + return false; + } + const std::string &type = type_json.get(); + return type == "DerivedClass1"; +} + +absl::StatusOr> +VDerivedClass1::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + + return absl::make_unique( + ); +} + +// ============================================================================= +// VDerivedClass2 +// ============================================================================= + +static bool IsDerivedClass2(const nlohmann::json& json) { + if (!json.is_object()) { + return false; + } + auto type_it = json.find("type"); + if (type_it == json.end()) { + return false; + } + const nlohmann::json &type_json = type_it.value(); + if (!type_json.is_string()) { + return false; + } + const std::string &type = type_json.get(); + return type == "DerivedClass2"; +} + +absl::StatusOr> +VDerivedClass2::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return absl::InvalidArgumentError("JSON is not an object."); + } + + + return absl::make_unique( + ); +} + +// ============================================================================= +// VNode +// 
============================================================================= + +absl::StatusOr> +VNode::GetSimpleVariantBuiltin(const nlohmann::json& json) { + auto simple_variant_builtin_it = json.find("simpleVariantBuiltin"); + if (simple_variant_builtin_it == json.end()) { + return absl::InvalidArgumentError("`simpleVariantBuiltin` is undefined."); + } + const nlohmann::json& json_simple_variant_builtin = simple_variant_builtin_it.value(); + + if (json_simple_variant_builtin.is_null()) { + return absl::InvalidArgumentError("json_simple_variant_builtin is null."); + } + if (json_simple_variant_builtin.is_number()) { + return json_simple_variant_builtin.get(); + } else if (json_simple_variant_builtin.is_string()) { + return json_simple_variant_builtin.get(); + } else { + auto result = absl::InvalidArgumentError("json_simple_variant_builtin has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_simple_variant_builtin.dump()}); + return result; + } +} + +absl::StatusOr>> +VNode::GetNullableVariantBuiltin(const nlohmann::json& json) { + auto nullable_variant_builtin_it = json.find("nullableVariantBuiltin"); + if (nullable_variant_builtin_it == json.end()) { + return absl::InvalidArgumentError("`nullableVariantBuiltin` is undefined."); + } + const nlohmann::json& json_nullable_variant_builtin = nullable_variant_builtin_it.value(); + + if (json_nullable_variant_builtin.is_null()) { + return std::nullopt; + } + if (json_nullable_variant_builtin.is_number()) { + return json_nullable_variant_builtin.get(); + } else if (json_nullable_variant_builtin.is_string()) { + return json_nullable_variant_builtin.get(); + } else { + auto result = absl::InvalidArgumentError("json_nullable_variant_builtin has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_nullable_variant_builtin.dump()}); + return result; + } +} + +absl::StatusOr>> 
+VNode::GetOptionalVariantBuiltin(const nlohmann::json& json) { + auto optional_variant_builtin_it = json.find("optionalVariantBuiltin"); + if (optional_variant_builtin_it == json.end()) { + return std::nullopt; + } + const nlohmann::json& json_optional_variant_builtin = optional_variant_builtin_it.value(); + + if (json_optional_variant_builtin.is_null()) { + return absl::InvalidArgumentError("json_optional_variant_builtin is null."); + } + if (json_optional_variant_builtin.is_number()) { + return json_optional_variant_builtin.get(); + } else if (json_optional_variant_builtin.is_string()) { + return json_optional_variant_builtin.get(); + } else { + auto result = absl::InvalidArgumentError("json_optional_variant_builtin has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_optional_variant_builtin.dump()}); + return result; + } +} + +absl::StatusOr, std::unique_ptr>> +VNode::GetSimpleVariantClass(const nlohmann::json& json) { + auto simple_variant_class_it = json.find("simpleVariantClass"); + if (simple_variant_class_it == json.end()) { + return absl::InvalidArgumentError("`simpleVariantClass` is undefined."); + } + const nlohmann::json& json_simple_variant_class = simple_variant_class_it.value(); + + if (json_simple_variant_class.is_null()) { + return absl::InvalidArgumentError("json_simple_variant_class is null."); + } + if (IsDerivedClass1(json_simple_variant_class)) { + return VDerivedClass1::FromJson(json_simple_variant_class); + } else if (IsDerivedClass2(json_simple_variant_class)) { + return VDerivedClass2::FromJson(json_simple_variant_class); + } else { + auto result = absl::InvalidArgumentError("json_simple_variant_class has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_simple_variant_class.dump()}); + return result; + } +} + +absl::StatusOr, std::unique_ptr>>> +VNode::GetNullableVariantClass(const 
nlohmann::json& json) { + auto nullable_variant_class_it = json.find("nullableVariantClass"); + if (nullable_variant_class_it == json.end()) { + return absl::InvalidArgumentError("`nullableVariantClass` is undefined."); + } + const nlohmann::json& json_nullable_variant_class = nullable_variant_class_it.value(); + + if (json_nullable_variant_class.is_null()) { + return std::nullopt; + } + if (IsDerivedClass1(json_nullable_variant_class)) { + return VDerivedClass1::FromJson(json_nullable_variant_class); + } else if (IsDerivedClass2(json_nullable_variant_class)) { + return VDerivedClass2::FromJson(json_nullable_variant_class); + } else { + auto result = absl::InvalidArgumentError("json_nullable_variant_class has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_nullable_variant_class.dump()}); + return result; + } +} + +absl::StatusOr, std::unique_ptr>>> +VNode::GetOptionalVariantClass(const nlohmann::json& json) { + auto optional_variant_class_it = json.find("optionalVariantClass"); + if (optional_variant_class_it == json.end()) { + return std::nullopt; + } + const nlohmann::json& json_optional_variant_class = optional_variant_class_it.value(); + + if (json_optional_variant_class.is_null()) { + return absl::InvalidArgumentError("json_optional_variant_class is null."); + } + if (IsDerivedClass1(json_optional_variant_class)) { + return VDerivedClass1::FromJson(json_optional_variant_class); + } else if (IsDerivedClass2(json_optional_variant_class)) { + return VDerivedClass2::FromJson(json_optional_variant_class); + } else { + auto result = absl::InvalidArgumentError("json_optional_variant_class has invalid type."); + result.SetPayload("json", absl::Cord{json.dump()}); + result.SetPayload("json_element", absl::Cord{json_optional_variant_class.dump()}); + return result; + } +} + +absl::StatusOr> +VNode::FromJson(const nlohmann::json& json) { + if (!json.is_object()) { + return 
absl::InvalidArgumentError("JSON is not an object."); + } + + MALDOCA_ASSIGN_OR_RETURN(auto simple_variant_builtin, VNode::GetSimpleVariantBuiltin(json)); + MALDOCA_ASSIGN_OR_RETURN(auto nullable_variant_builtin, VNode::GetNullableVariantBuiltin(json)); + MALDOCA_ASSIGN_OR_RETURN(auto optional_variant_builtin, VNode::GetOptionalVariantBuiltin(json)); + MALDOCA_ASSIGN_OR_RETURN(auto simple_variant_class, VNode::GetSimpleVariantClass(json)); + MALDOCA_ASSIGN_OR_RETURN(auto nullable_variant_class, VNode::GetNullableVariantClass(json)); + MALDOCA_ASSIGN_OR_RETURN(auto optional_variant_class, VNode::GetOptionalVariantClass(json)); + + return absl::make_unique( + std::move(simple_variant_builtin), + std::move(nullable_variant_builtin), + std::move(optional_variant_builtin), + std::move(simple_variant_class), + std::move(nullable_variant_class), + std::move(optional_variant_class)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/variant/ast_gen_test.cc b/maldoca/astgen/test/variant/ast_gen_test.cc new file mode 100644 index 0000000..e3302e7 --- /dev/null +++ b/maldoca/astgen/test/variant/ast_gen_test.cc @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/ast_gen_test_util.h" + +namespace maldoca { +namespace { + +INSTANTIATE_TEST_SUITE_P( + Variant, AstGenTest, + ::testing::Values(AstGenTestParam{ + .ast_def_path = + "maldoca/astgen/test/variant/ast_def.textproto", + .ts_interface_path = "maldoca/astgen/test/" + "variant/ast_ts_interface.generated", + .cc_namespace = "maldoca", + .ast_path = "maldoca/astgen/test/variant", + .ir_path = "maldoca/astgen/test/variant", + .expected_ast_header_path = + "maldoca/astgen/test/variant/ast.generated.h", + .expected_ast_source_path = + "maldoca/astgen/test/variant/ast.generated.cc", + .expected_ast_to_json_path = + "maldoca/astgen/test/" + "variant/ast_to_json.generated.cc", + .expected_ast_from_json_path = + "maldoca/astgen/test/" + "variant/ast_from_json.generated.cc", + .expected_ir_tablegen_path = + "maldoca/astgen/test/" + "variant/vir_ops.generated.td", + .expected_ast_to_ir_source_path = + "maldoca/astgen/test/" + "variant/conversion/ast_to_vir.generated.cc", + .expected_ir_to_ast_source_path = + "maldoca/astgen/test/" + "variant/conversion/vir_to_ast.generated.cc", + })); + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/variant/ast_to_json.generated.cc b/maldoca/astgen/test/variant/ast_to_json.generated.cc new file mode 100644 index 0000000..7702a1b --- /dev/null +++ b/maldoca/astgen/test/variant/ast_to_json.generated.cc @@ -0,0 +1,207 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include +#include +#include +#include +#include + +#include "maldoca/astgen/test/variant/ast.generated.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "nlohmann/json.hpp" +#include "maldoca/base/status_macros.h" + +namespace maldoca { + +void MaybeAddComma(std::ostream &os, bool &needs_comma) { + if (needs_comma) { + os << ","; + } + needs_comma = true; +} + +// ============================================================================= +// VBaseClass +// ============================================================================= + +void VBaseClass::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +// ============================================================================= +// VDerivedClass1 +// ============================================================================= + +void VDerivedClass1::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + +void VDerivedClass1::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"DerivedClass1\""; + VBaseClass::SerializeFields(os, needs_comma); + VDerivedClass1::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// VDerivedClass2 +// ============================================================================= + +void VDerivedClass2::SerializeFields(std::ostream& os, bool &needs_comma) const { +} + 
+void VDerivedClass2::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + MaybeAddComma(os, needs_comma); + os << "\"type\":\"DerivedClass2\""; + VBaseClass::SerializeFields(os, needs_comma); + VDerivedClass2::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// ============================================================================= +// VNode +// ============================================================================= + +void VNode::SerializeFields(std::ostream& os, bool &needs_comma) const { + MaybeAddComma(os, needs_comma); + switch (simple_variant_builtin_.index()) { + case 0: { + os << "\"simpleVariantBuiltin\":" << (nlohmann::json(std::get<0>(simple_variant_builtin_))).dump(); + break; + } + case 1: { + os << "\"simpleVariantBuiltin\":" << (nlohmann::json(std::get<1>(simple_variant_builtin_))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + MaybeAddComma(os, needs_comma); + if (nullable_variant_builtin_.has_value()) { + switch (nullable_variant_builtin_.value().index()) { + case 0: { + os << "\"nullableVariantBuiltin\":" << (nlohmann::json(std::get<0>(nullable_variant_builtin_.value()))).dump(); + break; + } + case 1: { + os << "\"nullableVariantBuiltin\":" << (nlohmann::json(std::get<1>(nullable_variant_builtin_.value()))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } else { + os << "\"nullableVariantBuiltin\":" << "null"; + } + if (optional_variant_builtin_.has_value()) { + MaybeAddComma(os, needs_comma); + switch (optional_variant_builtin_.value().index()) { + case 0: { + os << "\"optionalVariantBuiltin\":" << (nlohmann::json(std::get<0>(optional_variant_builtin_.value()))).dump(); + break; + } + case 1: { + os << "\"optionalVariantBuiltin\":" << (nlohmann::json(std::get<1>(optional_variant_builtin_.value()))).dump(); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + MaybeAddComma(os, needs_comma); + switch 
(simple_variant_class_.index()) { + case 0: { + os << "\"simpleVariantClass\":"; + std::get<0>(simple_variant_class_)->Serialize(os); + break; + } + case 1: { + os << "\"simpleVariantClass\":"; + std::get<1>(simple_variant_class_)->Serialize(os); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + MaybeAddComma(os, needs_comma); + if (nullable_variant_class_.has_value()) { + switch (nullable_variant_class_.value().index()) { + case 0: { + os << "\"nullableVariantClass\":"; + std::get<0>(nullable_variant_class_.value())->Serialize(os); + break; + } + case 1: { + os << "\"nullableVariantClass\":"; + std::get<1>(nullable_variant_class_.value())->Serialize(os); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } else { + os << "\"nullableVariantClass\":" << "null"; + } + if (optional_variant_class_.has_value()) { + MaybeAddComma(os, needs_comma); + switch (optional_variant_class_.value().index()) { + case 0: { + os << "\"optionalVariantClass\":"; + std::get<0>(optional_variant_class_.value())->Serialize(os); + break; + } + case 1: { + os << "\"optionalVariantClass\":"; + std::get<1>(optional_variant_class_.value())->Serialize(os); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } +} + +void VNode::Serialize(std::ostream& os) const { + os << "{"; + { + bool needs_comma = false; + VNode::SerializeFields(os, needs_comma); + } + os << "}"; +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/variant/ast_ts_interface.generated b/maldoca/astgen/test/variant/ast_ts_interface.generated new file mode 100644 index 0000000..9d96c1a --- /dev/null +++ b/maldoca/astgen/test/variant/ast_ts_interface.generated @@ -0,0 +1,17 @@ +interface BaseClass { +} + +interface DerivedClass1 <: BaseClass { +} + +interface DerivedClass2 <: BaseClass { +} + +interface Node { + simpleVariantBuiltin: /*double*/number | string + nullableVariantBuiltin: 
/*double*/number | string | null + optionalVariantBuiltin?: /*double*/number | string + simpleVariantClass: DerivedClass1 | DerivedClass2 + nullableVariantClass: DerivedClass1 | DerivedClass2 | null + optionalVariantClass?: DerivedClass1 | DerivedClass2 +} diff --git a/maldoca/astgen/test/variant/conversion/BUILD b/maldoca/astgen/test/variant/conversion/BUILD new file mode 100644 index 0000000..234f91a --- /dev/null +++ b/maldoca/astgen/test/variant/conversion/BUILD @@ -0,0 +1,77 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_applicable_licenses = ["//:license"]) + +licenses(["notice"]) + +exports_files([ + "ast_to_vir.generated.cc", + "vir_to_ast.generated.cc", +]) + +cc_library( + name = "ast_to_vir", + srcs = ["ast_to_vir.generated.cc"], + hdrs = ["ast_to_vir.h"], + deps = [ + "//maldoca/astgen/test/variant:ast", + "//maldoca/astgen/test/variant:ir", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "vir_to_ast", + srcs = ["vir_to_ast.generated.cc"], + hdrs = ["vir_to_ast.h"], + deps = [ + "//maldoca/astgen/test/variant:ast", + "//maldoca/astgen/test/variant:ir", + "//maldoca/base:status", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_test( + name = "conversion_test", + srcs = ["conversion_test.cc"], + deps = [ + ":ast_to_vir", + ":vir_to_ast", + "//maldoca/astgen/test:conversion_test_util", + "//maldoca/astgen/test/variant:ast", + "//maldoca/astgen/test/variant:ir", + "@googletest//:gtest_main", + ], +) diff --git a/maldoca/astgen/test/variant/conversion/ast_to_vir.generated.cc b/maldoca/astgen/test/variant/conversion/ast_to_vir.generated.cc new file mode 100644 index 0000000..5d78c9d --- /dev/null +++ b/maldoca/astgen/test/variant/conversion/ast_to_vir.generated.cc @@ -0,0 +1,160 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with 
the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/variant/conversion/ast_to_vir.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/astgen/test/variant/ast.generated.h" +#include "maldoca/astgen/test/variant/ir.h" + +namespace maldoca { + +VirBaseClassOpInterface AstToVir::VisitBaseClass(const VBaseClass *node) { + if (auto *derived_class1 = dynamic_cast(node)) { + return VisitDerivedClass1(derived_class1); + } + if (auto *derived_class2 = dynamic_cast(node)) { + return VisitDerivedClass2(derived_class2); + } + LOG(FATAL) << "Unreachable code."; +} + +VirDerivedClass1Op AstToVir::VisitDerivedClass1(const VDerivedClass1 *node) { + return CreateExpr(node); +} + +VirDerivedClass2Op AstToVir::VisitDerivedClass2(const VDerivedClass2 *node) { + return CreateExpr(node); +} + +VirNodeOp 
AstToVir::VisitNode(const VNode *node) { + mlir::Attribute mlir_simple_variant_builtin; + switch (node->simple_variant_builtin().index()) { + case 0: { + mlir_simple_variant_builtin = builder_.getF64FloatAttr(std::get<0>(node->simple_variant_builtin())); + break; + } + case 1: { + mlir_simple_variant_builtin = builder_.getStringAttr(std::get<1>(node->simple_variant_builtin())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + mlir::Attribute mlir_nullable_variant_builtin; + if (node->nullable_variant_builtin().has_value()) { + switch (node->nullable_variant_builtin().value().index()) { + case 0: { + mlir_nullable_variant_builtin = builder_.getF64FloatAttr(std::get<0>(node->nullable_variant_builtin().value())); + break; + } + case 1: { + mlir_nullable_variant_builtin = builder_.getStringAttr(std::get<1>(node->nullable_variant_builtin().value())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + mlir::Attribute mlir_optional_variant_builtin; + if (node->optional_variant_builtin().has_value()) { + switch (node->optional_variant_builtin().value().index()) { + case 0: { + mlir_optional_variant_builtin = builder_.getF64FloatAttr(std::get<0>(node->optional_variant_builtin().value())); + break; + } + case 1: { + mlir_optional_variant_builtin = builder_.getStringAttr(std::get<1>(node->optional_variant_builtin().value())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + mlir::Value mlir_simple_variant_class; + switch (node->simple_variant_class().index()) { + case 0: { + mlir_simple_variant_class = VisitDerivedClass1(std::get<0>(node->simple_variant_class())); + break; + } + case 1: { + mlir_simple_variant_class = VisitDerivedClass2(std::get<1>(node->simple_variant_class())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + mlir::Value mlir_nullable_variant_class; + if (node->nullable_variant_class().has_value()) { + switch (node->nullable_variant_class().value().index()) { + case 0: { + 
mlir_nullable_variant_class = VisitDerivedClass1(std::get<0>(node->nullable_variant_class().value())); + break; + } + case 1: { + mlir_nullable_variant_class = VisitDerivedClass2(std::get<1>(node->nullable_variant_class().value())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + mlir::Value mlir_optional_variant_class; + if (node->optional_variant_class().has_value()) { + switch (node->optional_variant_class().value().index()) { + case 0: { + mlir_optional_variant_class = VisitDerivedClass1(std::get<0>(node->optional_variant_class().value())); + break; + } + case 1: { + mlir_optional_variant_class = VisitDerivedClass2(std::get<1>(node->optional_variant_class().value())); + break; + } + default: + LOG(FATAL) << "Unreachable code."; + } + } + return CreateExpr(node, mlir_simple_variant_builtin, mlir_nullable_variant_builtin, mlir_optional_variant_builtin, mlir_simple_variant_class, mlir_nullable_variant_class, mlir_optional_variant_class); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/variant/conversion/ast_to_vir.h b/maldoca/astgen/test/variant/conversion/ast_to_vir.h new file mode 100644 index 0000000..a295bbd --- /dev/null +++ b/maldoca/astgen/test/variant/conversion/ast_to_vir.h @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TEST_VARIANT_CONVERSION_AST_TO_VIR_H_ +#define MALDOCA_ASTGEN_TEST_VARIANT_CONVERSION_AST_TO_VIR_H_ + +#include "mlir/IR/Builders.h" +#include "maldoca/astgen/test/variant/ast.generated.h" +#include "maldoca/astgen/test/variant/ir.h" + +namespace maldoca { + +class AstToVir { + public: + explicit AstToVir(mlir::OpBuilder &builder) : builder_(builder) {} + + VirBaseClassOpInterface VisitBaseClass(const VBaseClass *node); + + VirDerivedClass1Op VisitDerivedClass1(const VDerivedClass1 *node); + + VirDerivedClass2Op VisitDerivedClass2(const VDerivedClass2 *node); + + VirNodeOp VisitNode(const VNode *node); + + private: + template + Op CreateExpr(const VNode *node, Args &&...args) { + return builder_.create(builder_.getUnknownLoc(), + std::forward(args)...); + } + + mlir::OpBuilder &builder_; +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_VARIANT_CONVERSION_AST_TO_VIR_H_ diff --git a/maldoca/astgen/test/variant/conversion/conversion_test.cc b/maldoca/astgen/test/variant/conversion/conversion_test.cc new file mode 100644 index 0000000..037047b --- /dev/null +++ b/maldoca/astgen/test/variant/conversion/conversion_test.cc @@ -0,0 +1,123 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "maldoca/astgen/test/conversion_test_util.h" +#include "maldoca/astgen/test/variant/ast.generated.h" +#include "maldoca/astgen/test/variant/conversion/ast_to_vir.h" +#include "maldoca/astgen/test/variant/conversion/vir_to_ast.h" +#include "maldoca/astgen/test/variant/ir.h" + +namespace maldoca { +namespace { + +TEST(ConversionTest, Double) { + constexpr char kAstJsonString[] = R"( + { + "simpleVariantBuiltin": 1, + "nullableVariantBuiltin": 1, + "optionalVariantBuiltin": 1, + "simpleVariantClass": { + "type": "DerivedClass1" + }, + "nullableVariantClass": { + "type": "DerivedClass1" + }, + "optionalVariantClass": { + "type": "DerivedClass2" +} + } + )"; + + constexpr char kExpectedIr[] = R"( +module { + %0 = "vir.derived_class1"() : () -> !vir.any + %1 = "vir.derived_class1"() : () -> !vir.any + %2 = "vir.derived_class2"() : () -> !vir.any + %3 = "vir.node"(%0, %1, %2) <{nullable_variant_builtin = 1.000000e+00 : f64, operandSegmentSizes = array, optional_variant_builtin = 1.000000e+00 : f64, simple_variant_builtin = 1.000000e+00 : f64}> : (!vir.any, !vir.any, !vir.any) -> !vir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToVir::VisitNode, + .ir_to_ast_visit = &VirToAst::VisitNode, + .expected_ir_dump = kExpectedIr, + }); +} + +TEST(ConversionTest, String) { + constexpr char kAstJsonString[] = R"( + { + "simpleVariantBuiltin": "1", + "nullableVariantBuiltin": "1", + "optionalVariantBuiltin": "1", + "simpleVariantClass": { + "type": "DerivedClass1" + }, + "nullableVariantClass": { + "type": "DerivedClass1" + }, + "optionalVariantClass": { + "type": "DerivedClass2" + } + } + )"; + + constexpr char kExpectedIr[] = R"( +module { + %0 = "vir.derived_class1"() : () -> !vir.any + %1 = "vir.derived_class1"() : () -> !vir.any + %2 = "vir.derived_class2"() : () -> !vir.any + %3 = "vir.node"(%0, %1, %2) <{nullable_variant_builtin = "1", operandSegmentSizes = array, 
optional_variant_builtin = "1", simple_variant_builtin = "1"}> : (!vir.any, !vir.any, !vir.any) -> !vir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToVir::VisitNode, + .ir_to_ast_visit = &VirToAst::VisitNode, + .expected_ir_dump = kExpectedIr, + }); +} + +TEST(ConversionTest, Nullopt) { + constexpr char kAstJsonString[] = R"( + { + "simpleVariantBuiltin": "1", + "nullableVariantBuiltin": null, + "simpleVariantClass": { + "type": "DerivedClass1" + }, + "nullableVariantClass": null + } + )"; + + constexpr char kExpectedIr[] = R"( +module { + %0 = "vir.derived_class1"() : () -> !vir.any + %1 = "vir.node"(%0) <{operandSegmentSizes = array, simple_variant_builtin = "1"}> : (!vir.any) -> !vir.any +} + )"; + + TestIrConversion({ + .ast_json_string = kAstJsonString, + .ast_to_ir_visit = &AstToVir::VisitNode, + .ir_to_ast_visit = &VirToAst::VisitNode, + .expected_ir_dump = kExpectedIr, + }); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/astgen/test/variant/conversion/vir_to_ast.generated.cc b/maldoca/astgen/test/variant/conversion/vir_to_ast.generated.cc new file mode 100644 index 0000000..b0cac49 --- /dev/null +++ b/maldoca/astgen/test/variant/conversion/vir_to_ast.generated.cc @@ -0,0 +1,156 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ============================================================================= +// STOP!! DO NOT MODIFY!! 
THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +// IWYU pragma: begin_keep +// NOLINTBEGIN(whitespace/line_length) +// clang-format off + +#include "maldoca/astgen/test/variant/conversion/vir_to_ast.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "maldoca/base/status_macros.h" +#include "maldoca/astgen/test/variant/ast.generated.h" +#include "maldoca/astgen/test/variant/ir.h" + +namespace maldoca { + +absl::StatusOr> +VirToAst::VisitBaseClass(VirBaseClassOpInterface op) { + using Ret = absl::StatusOr>; + return llvm::TypeSwitch(op) + .Case([&](VirDerivedClass1Op op) { + return VisitDerivedClass1(op); + }) + .Case([&](VirDerivedClass2Op op) { + return VisitDerivedClass2(op); + }) + .Default([&](mlir::Operation* op) { + return absl::InvalidArgumentError("Unrecognized op"); + }); +} + +absl::StatusOr> +VirToAst::VisitDerivedClass1(VirDerivedClass1Op op) { + return Create( + op); +} + +absl::StatusOr> +VirToAst::VisitDerivedClass2(VirDerivedClass2Op op) { + return Create( + op); +} + +absl::StatusOr> +VirToAst::VisitNode(VirNodeOp op) { + std::variant simple_variant_builtin; + if (auto mlir_simple_variant_builtin = llvm::dyn_cast(op.getSimpleVariantBuiltinAttr())) { + simple_variant_builtin = mlir_simple_variant_builtin.getValueAsDouble(); + } else if (auto 
mlir_simple_variant_builtin = llvm::dyn_cast(op.getSimpleVariantBuiltinAttr())) { + simple_variant_builtin = mlir_simple_variant_builtin.str(); + } else { + return absl::InvalidArgumentError("op.getSimpleVariantBuiltinAttr() has invalid type."); + } + std::optional> nullable_variant_builtin; + if (op.getNullableVariantBuiltinAttr() != nullptr) { + if (auto mlir_nullable_variant_builtin = llvm::dyn_cast(op.getNullableVariantBuiltinAttr())) { + nullable_variant_builtin = mlir_nullable_variant_builtin.getValueAsDouble(); + } else if (auto mlir_nullable_variant_builtin = llvm::dyn_cast(op.getNullableVariantBuiltinAttr())) { + nullable_variant_builtin = mlir_nullable_variant_builtin.str(); + } else { + return absl::InvalidArgumentError("op.getNullableVariantBuiltinAttr() has invalid type."); + } + } + std::optional> optional_variant_builtin; + if (op.getOptionalVariantBuiltinAttr() != nullptr) { + if (auto mlir_optional_variant_builtin = llvm::dyn_cast(op.getOptionalVariantBuiltinAttr())) { + optional_variant_builtin = mlir_optional_variant_builtin.getValueAsDouble(); + } else if (auto mlir_optional_variant_builtin = llvm::dyn_cast(op.getOptionalVariantBuiltinAttr())) { + optional_variant_builtin = mlir_optional_variant_builtin.str(); + } else { + return absl::InvalidArgumentError("op.getOptionalVariantBuiltinAttr() has invalid type."); + } + } + std::variant, std::unique_ptr> simple_variant_class; + if (auto mlir_simple_variant_class = llvm::dyn_cast(op.getSimpleVariantClass().getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(simple_variant_class, VisitDerivedClass1(mlir_simple_variant_class)); + } else if (auto mlir_simple_variant_class = llvm::dyn_cast(op.getSimpleVariantClass().getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(simple_variant_class, VisitDerivedClass2(mlir_simple_variant_class)); + } else { + return absl::InvalidArgumentError("op.getSimpleVariantClass().getDefiningOp() has invalid type."); + } + std::optional, std::unique_ptr>> nullable_variant_class; + if 
(op.getNullableVariantClass() != nullptr) { + if (auto mlir_nullable_variant_class = llvm::dyn_cast(op.getNullableVariantClass().getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(nullable_variant_class, VisitDerivedClass1(mlir_nullable_variant_class)); + } else if (auto mlir_nullable_variant_class = llvm::dyn_cast(op.getNullableVariantClass().getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(nullable_variant_class, VisitDerivedClass2(mlir_nullable_variant_class)); + } else { + return absl::InvalidArgumentError("op.getNullableVariantClass().getDefiningOp() has invalid type."); + } + } + std::optional, std::unique_ptr>> optional_variant_class; + if (op.getOptionalVariantClass() != nullptr) { + if (auto mlir_optional_variant_class = llvm::dyn_cast(op.getOptionalVariantClass().getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(optional_variant_class, VisitDerivedClass1(mlir_optional_variant_class)); + } else if (auto mlir_optional_variant_class = llvm::dyn_cast(op.getOptionalVariantClass().getDefiningOp())) { + MALDOCA_ASSIGN_OR_RETURN(optional_variant_class, VisitDerivedClass2(mlir_optional_variant_class)); + } else { + return absl::InvalidArgumentError("op.getOptionalVariantClass().getDefiningOp() has invalid type."); + } + } + return Create( + op, + std::move(simple_variant_builtin), + std::move(nullable_variant_builtin), + std::move(optional_variant_builtin), + std::move(simple_variant_class), + std::move(nullable_variant_class), + std::move(optional_variant_class)); +} + +// clang-format on +// NOLINTEND(whitespace/line_length) +// IWYU pragma: end_keep + +} // namespace maldoca diff --git a/maldoca/astgen/test/variant/conversion/vir_to_ast.h b/maldoca/astgen/test/variant/conversion/vir_to_ast.h new file mode 100644 index 0000000..d98e17a --- /dev/null +++ b/maldoca/astgen/test/variant/conversion/vir_to_ast.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_VARIANT_CONVERSION_VIR_TO_AST_H_ +#define MALDOCA_ASTGEN_TEST_VARIANT_CONVERSION_VIR_TO_AST_H_ + +#include + +#include "mlir/IR/Operation.h" +#include "absl/status/statusor.h" +#include "maldoca/astgen/test/variant/ast.generated.h" +#include "maldoca/astgen/test/variant/ir.h" + +namespace maldoca { + +class VirToAst { + public: + absl::StatusOr> VisitBaseClass( + VirBaseClassOpInterface op); + + absl::StatusOr> VisitDerivedClass1( + VirDerivedClass1Op op); + + absl::StatusOr> VisitDerivedClass2( + VirDerivedClass2Op op); + + absl::StatusOr> VisitNode(VirNodeOp op); + + private: + template + std::unique_ptr Create(mlir::Operation *op, Args &&...args) { + return absl::make_unique(std::forward(args)...); + } +}; + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TEST_VARIANT_CONVERSION_VIR_TO_AST_H_ diff --git a/maldoca/astgen/test/variant/interfaces.td b/maldoca/astgen/test/variant/interfaces.td new file mode 100644 index 0000000..78f2afc --- /dev/null +++ b/maldoca/astgen/test/variant/interfaces.td @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "mlir/IR/OpBase.td" + +def VirBaseClassOpInterface : OpInterface<"VirBaseClassOpInterface"> { + let cppNamespace = "::maldoca"; + + let extraClassDeclaration = [{ + operator mlir::Value() { // NOLINT + return getOperation()->getResult(0); + } + }]; +} + +def VirBaseClassOpInterfaceTraits : TraitList<[ + DeclareOpInterfaceMethods +]>; diff --git a/maldoca/astgen/test/variant/ir.cc b/maldoca/astgen/test/variant/ir.cc new file mode 100644 index 0000000..1371849 --- /dev/null +++ b/maldoca/astgen/test/variant/ir.cc @@ -0,0 +1,63 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "maldoca/astgen/test/variant/ir.h" + +// IWYU pragma: begin_keep + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" + +// IWYU pragma: end_keep + +// ============================================================================= +// Dialect Definition +// ============================================================================= + +#include "maldoca/astgen/test/variant/vir_dialect.cc.inc" + +/// Dialect initialization, the instance will be owned by the context. This is +/// the point of registration of types and operations for the dialect. +void maldoca::VirDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "maldoca/astgen/test/variant/vir_types.cc.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "maldoca/astgen/test/variant/vir_ops.generated.cc.inc" + >(); +} + +// ============================================================================= +// Dialect Interface Definitions +// ============================================================================= + +#include "maldoca/astgen/test/variant/interfaces.cc.inc" + +// ============================================================================= +// Dialect Type Definitions +// ============================================================================= + +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/variant/vir_types.cc.inc" + +// ============================================================================= +// Dialect Op Definitions +// ============================================================================= + +#define GET_OP_CLASSES +#include "maldoca/astgen/test/variant/vir_ops.generated.cc.inc" diff --git a/maldoca/astgen/test/variant/ir.h b/maldoca/astgen/test/variant/ir.h new file mode 100644 index 0000000..1fc6d0b --- /dev/null +++ b/maldoca/astgen/test/variant/ir.h @@ -0,0 +1,40 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_VARIANT_IR_H_ +#define MALDOCA_ASTGEN_TEST_VARIANT_IR_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" + +// Include the auto-generated header file containing the declaration of the VIR +// dialect. +#include "maldoca/astgen/test/variant/vir_dialect.h.inc" + +// Include the auto-generated header file containing the declarations of the VIR +// interfaces. +#include "maldoca/astgen/test/variant/interfaces.h.inc" + +// Include the auto-generated header file containing the declarations of the VIR +// types. +#define GET_TYPEDEF_CLASSES +#include "maldoca/astgen/test/variant/vir_types.h.inc" + +// Include the auto-generated header file containing the declarations of the VIR +// operations. +#define GET_OP_CLASSES +#include "maldoca/astgen/test/variant/vir_ops.generated.h.inc" + +#endif // MALDOCA_ASTGEN_TEST_VARIANT_IR_H_ diff --git a/maldoca/astgen/test/variant/vir_dialect.td b/maldoca/astgen/test/variant/vir_dialect.td new file mode 100644 index 0000000..516b479 --- /dev/null +++ b/maldoca/astgen/test/variant/vir_dialect.td @@ -0,0 +1,42 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_VARIANT_VIR_DIALECT_TD_ +#define MALDOCA_ASTGEN_TEST_VARIANT_VIR_DIALECT_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def Vir_Dialect : Dialect { + let name = "vir"; + let cppNamespace = "::maldoca"; + + let description = [{ + The VariantIR, a test IR that makes extensive use of variant types. All ops + and fields are directly mapped from the AST. + }]; + + let useDefaultTypePrinterParser = 1; +} + +class Vir_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = ?; +} + +class Vir_Op traits = []> : + Op; + +#endif // MALDOCA_ASTGEN_TEST_VARIANT_VIR_DIALECT_TD_ diff --git a/maldoca/astgen/test/variant/vir_ops.generated.td b/maldoca/astgen/test/variant/vir_ops.generated.td new file mode 100644 index 0000000..3864001 --- /dev/null +++ b/maldoca/astgen/test/variant/vir_ops.generated.td @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// ============================================================================= +// STOP!! DO NOT MODIFY!! THIS FILE IS AUTOMATICALLY GENERATED. +// ============================================================================= + +#ifndef MALDOCA_ASTGEN_TEST_VARIANT_VIR_OPS_GENERATED_TD_ +#define MALDOCA_ASTGEN_TEST_VARIANT_VIR_OPS_GENERATED_TD_ + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" +include "maldoca/astgen/test/variant/interfaces.td" +include "maldoca/astgen/test/variant/vir_dialect.td" +include "maldoca/astgen/test/variant/vir_types.td" + +def VirDerivedClass1Op : Vir_Op< + "derived_class1", [ + VirBaseClassOpInterfaceTraits + ]> { + let results = (outs + VirAnyType + ); +} + +def VirDerivedClass2Op : Vir_Op< + "derived_class2", [ + VirBaseClassOpInterfaceTraits + ]> { + let results = (outs + VirAnyType + ); +} + +def VirNodeOp : Vir_Op< + "node", [ + AttrSizedOperandSegments + ]> { + let arguments = (ins + AnyAttrOf<[F64Attr, StrAttr]>: $simple_variant_builtin, + OptionalAttr>: $nullable_variant_builtin, + OptionalAttr>: $optional_variant_builtin, + AnyType: $simple_variant_class, + Optional: $nullable_variant_class, + Optional: $optional_variant_class + ); + + let results = (outs + VirAnyType + ); +} + +#endif // MALDOCA_ASTGEN_TEST_VARIANT_VIR_OPS_GENERATED_TD_ diff --git a/maldoca/astgen/test/variant/vir_types.td b/maldoca/astgen/test/variant/vir_types.td new file mode 100644 index 0000000..597d6b8 --- /dev/null +++ b/maldoca/astgen/test/variant/vir_types.td @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MALDOCA_ASTGEN_TEST_VARIANT_VIR_TYPES_TD_ +#define MALDOCA_ASTGEN_TEST_VARIANT_VIR_TYPES_TD_ + +include "maldoca/astgen/test/variant/vir_dialect.td" + +def VirAnyType : Vir_Type<"VirAny"> { + let summary = "A placeholder singleton type."; + let mnemonic = "any"; + let assemblyFormat = ""; +} + +#endif // MALDOCA_ASTGEN_TEST_VARIANT_VIR_TYPES_TD_ diff --git a/maldoca/astgen/type.cc b/maldoca/astgen/type.cc new file mode 100644 index 0000000..704c6ee --- /dev/null +++ b/maldoca/astgen/type.cc @@ -0,0 +1,652 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "maldoca/astgen/type.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/symbol.h" +#include "maldoca/astgen/ast_def.h" +#include "maldoca/astgen/type.pb.h" +#include "maldoca/base/status_macros.h" + +namespace maldoca { +namespace { + +std::unique_ptr FromBoolTypePb(const BoolTypePb &pb) { + return absl::make_unique(BuiltinTypeKind::kBool, ""); +} + +std::unique_ptr FromInt64TypePb(const Int64TypePb &pb) { + return absl::make_unique(BuiltinTypeKind::kInt64, ""); +} + +std::unique_ptr FromDoubleTypePb(const DoubleTypePb &pb) { + return absl::make_unique(BuiltinTypeKind::kDouble, ""); +} + +std::unique_ptr FromStringTypePb(const StringTypePb &pb) { + return absl::make_unique(BuiltinTypeKind::kString, ""); +} + +std::unique_ptr FromEnumTypePb(absl::string_view enum_, + absl::string_view lang_name_) { + return absl::make_unique(Symbol(enum_), lang_name_); +} + +std::unique_ptr FromClassTypePb(absl::string_view class_, + absl::string_view lang_name_) { + return absl::make_unique(Symbol(class_), lang_name_); +} + +absl::StatusOr> FromVariantTypePb( + const VariantTypePb &pb, absl::string_view lang_name) { + std::vector> types; + for (const ScalarTypePb &type : pb.types()) { + switch (type.kind_case()) { + case ScalarTypePb::KindCase::KIND_NOT_SET: + return absl::InvalidArgumentError( + "Invalid variant element type: KIND_NOT_SET."); + + case ScalarTypePb::KindCase::kBool: + types.push_back(FromBoolTypePb(type.bool_())); + break; + + case ScalarTypePb::KindCase::kInt64: + 
types.push_back(FromInt64TypePb(type.int64())); + break; + + case ScalarTypePb::KindCase::kDouble: + types.push_back(FromDoubleTypePb(type.double_())); + break; + + case ScalarTypePb::KindCase::kString: + types.push_back(FromStringTypePb(type.string())); + break; + + case ScalarTypePb::KindCase::kEnum: + types.push_back(FromEnumTypePb(type.enum_(), lang_name)); + break; + + case ScalarTypePb::KindCase::kClass: + types.push_back(FromClassTypePb(type.class_(), lang_name)); + break; + } + } + + if (types.empty()) { + return absl::InvalidArgumentError("Empty variant type."); + } + + if (types.size() == 1) { + return absl::InvalidArgumentError("Variant with only one case."); + } + + return absl::make_unique(std::move(types), lang_name); +} + +absl::StatusOr> FromListTypePb( + const ListTypePb &pb, absl::string_view lang_name) { + std::unique_ptr element_type; + switch (pb.element_type().kind_case()) { + case NonListTypePb::KIND_NOT_SET: + return absl::InvalidArgumentError( + "Invalid list element type: KIND_NOT_SET."); + + case NonListTypePb::KindCase::kBool: + element_type = FromBoolTypePb(pb.element_type().bool_()); + break; + + case NonListTypePb::KindCase::kInt64: + element_type = FromInt64TypePb(pb.element_type().int64()); + break; + + case NonListTypePb::KindCase::kDouble: + element_type = FromDoubleTypePb(pb.element_type().double_()); + break; + + case NonListTypePb::KindCase::kString: + element_type = FromStringTypePb(pb.element_type().string()); + break; + + case NonListTypePb::KindCase::kEnum: + element_type = FromEnumTypePb(pb.element_type().enum_(), lang_name); + break; + + case NonListTypePb::KindCase::kClass: + element_type = FromClassTypePb(pb.element_type().class_(), lang_name); + break; + + case NonListTypePb::kVariant: { + MALDOCA_ASSIGN_OR_RETURN( + element_type, + FromVariantTypePb(pb.element_type().variant(), lang_name)); + break; + } + } + + return absl::make_unique( + std::move(element_type), + pb.element_maybe_null() ? 
MaybeNull::kYes : MaybeNull::kNo, lang_name); +} + +} // namespace + +absl::StatusOr> FromTypePb(const TypePb &pb, + absl::string_view lang_name) { + switch (pb.kind_case()) { + case TypePb::KindCase::KIND_NOT_SET: + return absl::InvalidArgumentError("Invalid TypePb: KIND_NOT_SET."); + + case TypePb::KindCase::kBool: + return FromBoolTypePb(pb.bool_()); + + case TypePb::KindCase::kInt64: + return FromInt64TypePb(pb.int64()); + + case TypePb::KindCase::kDouble: + return FromDoubleTypePb(pb.double_()); + + case TypePb::KindCase::kString: + return FromStringTypePb(pb.string()); + + case TypePb::KindCase::kEnum: + return FromEnumTypePb(pb.enum_(), lang_name); + + case TypePb::KindCase::kClass: + return FromClassTypePb(pb.class_(), lang_name); + + case TypePb::KindCase::kVariant: + return FromVariantTypePb(pb.variant(), lang_name); + + case TypePb::KindCase::kList: + return FromListTypePb(pb.list(), lang_name); + } +} + +// ============================================================================= +// JsType() +// ============================================================================= + +std::string Type::JsType(MaybeNull maybe_null) const { + std::string str = JsType(); + switch (maybe_null) { + case MaybeNull::kYes: + return absl::StrCat(std::move(str), " | null"); + case MaybeNull::kNo: + return str; + } +} + +std::string ListType::JsType() const { + return absl::StrCat("[ ", element_type().JsType(element_maybe_null()), " ]"); +} + +std::string VariantType::JsType() const { + std::vector type_strings; + for (const auto &type : types()) { + type_strings.push_back(type->JsType()); + } + return absl::StrJoin(type_strings, " | "); +} + +std::string BuiltinType::JsType() const { + switch (builtin_kind()) { + case BuiltinTypeKind::kBool: + return "boolean"; + case BuiltinTypeKind::kInt64: + return "/*int64*/number"; + case BuiltinTypeKind::kDouble: + return "/*double*/number"; + case BuiltinTypeKind::kString: + return "string"; + } +} + +std::string 
EnumType::JsType() const { return name().ToPascalCase(); } + +std::string ClassType::JsType() const { return name().ToPascalCase(); } + +// ============================================================================= +// CcType() +// ============================================================================= + +std::string Type::CcType(MaybeNull maybe_null) const { + switch (maybe_null) { + case MaybeNull::kYes: + return CcType(OPTIONALNESS_MAYBE_NULL); + case MaybeNull::kNo: + return CcType(); + } +} + +std::string Type::CcType(Optionalness optionalness) const { + std::string str = CcType(); + switch (optionalness) { + case OPTIONALNESS_MAYBE_NULL: + case OPTIONALNESS_MAYBE_UNDEFINED: + return absl::StrCat("std::optional<", std::move(str), ">"); + default: + return str; + } +} + +std::string ListType::CcType() const { + return absl::StrCat("std::vector<", + element_type().CcType(element_maybe_null()), ">"); +} + +std::string VariantType::CcType() const { + std::vector type_strings; + for (const auto &type : types()) { + type_strings.push_back(type->CcType()); + } + return absl::StrCat("std::variant<", absl::StrJoin(type_strings, ", "), ">"); +} + +std::string BuiltinType::CcType() const { + switch (builtin_kind()) { + case BuiltinTypeKind::kBool: + return "bool"; + case BuiltinTypeKind::kInt64: + return "int64_t"; + case BuiltinTypeKind::kDouble: + return "double"; + case BuiltinTypeKind::kString: + return "std::string"; + } +} + +std::string EnumType::CcType() const { + return (Symbol(lang_name_) + name()).ToPascalCase(); +} + +std::string ClassType::CcType() const { + return absl::StrCat("std::unique_ptr<", + (Symbol(lang_name_) + name()).ToPascalCase(), ">"); +} + +// ============================================================================= +// CcGetterType() +// ============================================================================= + +std::string Type::CcMutableGetterType() const { + return CcGetterType(CcGetterKind::kMutable); +} + +std::string 
Type::CcMutableGetterType(MaybeNull maybe_null) const { + return CcGetterType(CcGetterKind::kMutable, maybe_null); +} + +std::string Type::CcMutableGetterType(Optionalness optionalness) const { + return CcGetterType(CcGetterKind::kMutable, optionalness); +} + +std::string Type::CcConstGetterType() const { + return CcGetterType(CcGetterKind::kConst); +} + +std::string Type::CcConstGetterType(MaybeNull maybe_null) const { + return CcGetterType(CcGetterKind::kConst, maybe_null); +} + +std::string Type::CcConstGetterType(Optionalness optionalness) const { + return CcGetterType(CcGetterKind::kConst, optionalness); +} + +std::string Type::CcGetterType(CcGetterKind getter_kind, + MaybeNull maybe_null) const { + switch (maybe_null) { + case MaybeNull::kYes: + return CcGetterType(getter_kind, OPTIONALNESS_MAYBE_NULL); + case MaybeNull::kNo: + return CcGetterType(getter_kind, OPTIONALNESS_REQUIRED); + } +} + +std::string Type::CcGetterType(CcGetterKind getter_kind, + Optionalness optionalness) const { + std::string str = CcGetterType(getter_kind); + switch (optionalness) { + case OPTIONALNESS_MAYBE_NULL: + case OPTIONALNESS_MAYBE_UNDEFINED: + return absl::StrCat("std::optional<", std::move(str), ">"); + default: + return str; + } +} + +std::string ListType::CcGetterType(CcGetterKind getter_kind) const { + switch (getter_kind) { + case CcGetterKind::kMutable: + return absl::StrCat(CcType(), "*"); + case CcGetterKind::kConst: + return absl::StrCat("const ", CcType(), "*"); + } +} + +std::string VariantType::CcGetterType(CcGetterKind getter_kind) const { + std::vector type_strings; + for (const auto &type : types()) { + type_strings.push_back(type->CcGetterType(getter_kind)); + } + return absl::StrCat("std::variant<", absl::StrJoin(type_strings, ", "), ">"); +} + +std::string BuiltinType::CcGetterType(CcGetterKind getter_kind) const { + switch (builtin_kind()) { + case BuiltinTypeKind::kBool: + return "bool"; + case BuiltinTypeKind::kInt64: + return "int64_t"; + case 
BuiltinTypeKind::kDouble: + return "double"; + case BuiltinTypeKind::kString: + return "absl::string_view"; + } +} + +std::string EnumType::CcGetterType(CcGetterKind getter_kind) const { + return (Symbol(lang_name_) + name()).ToPascalCase(); +} + +std::string ClassType::CcGetterType(CcGetterKind getter_kind) const { + const auto class_name = (Symbol(lang_name_) + name()).ToPascalCase(); + switch (getter_kind) { + case CcGetterKind::kMutable: + return absl::StrCat(class_name, "*"); + case CcGetterKind::kConst: + return absl::StrCat("const ", class_name, "*"); + } +} + +// ============================================================================= +// CcMlirBuilderType() / CcMlirGetterType() +// ============================================================================= + +std::string ListType::CcMlirBuilderType(FieldKind kind) const { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + return "mlir::ArrayAttr"; + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + return "std::vector"; + case FIELD_KIND_STMT: + LOG(FATAL) << "List of statements not supported."; + } +} + +std::string ListType::CcMlirGetterType(FieldKind kind) const { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + return "mlir::ArrayAttr"; + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + return "mlir::OperandRange"; + case FIELD_KIND_STMT: + LOG(FATAL) << "List of statements not supported."; + } +} + +std::string VariantType::CcMlirType(FieldKind kind) const { + absl::flat_hash_set cc_mlir_types; + for (const auto &type : types()) { + cc_mlir_types.insert(type->CcMlirType(kind)); + } + + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: { + if (cc_mlir_types.size() == 1) { + return *cc_mlir_types.begin(); + } + return "mlir::Attribute"; + } + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: { + 
CHECK_EQ(cc_mlir_types.size(), 1); + return *cc_mlir_types.begin(); + } + case FIELD_KIND_STMT: + LOG(FATAL) << "Variant of statements not supported."; + } +} + +std::string BuiltinType::CcMlirType(FieldKind kind) const { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + break; + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + case FIELD_KIND_STMT: + LOG(FATAL) << "Invalid FieldKind: " << kind; + } + + switch (builtin_kind()) { + case BuiltinTypeKind::kBool: + return "mlir::BoolAttr"; + case BuiltinTypeKind::kInt64: + return "mlir::IntegerAttr"; + case BuiltinTypeKind::kDouble: + return "mlir::FloatAttr"; + case BuiltinTypeKind::kString: + return "mlir::StringAttr"; + } +} + +std::string EnumType::CcMlirType(FieldKind kind) const { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + break; + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + case FIELD_KIND_STMT: + LOG(FATAL) << "Invalid FieldKind: " << kind; + } + + return "mlir::StringAttr"; +} + +std::string ClassType::CcMlirType(FieldKind kind) const { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: { + if (node_def_ != nullptr) { + auto ir_op_name = node_def_->ir_op_name(lang_name_, kind); + if (ir_op_name.has_value()) { + return ir_op_name->ToPascalCase(); + } + } + + auto ir_name = Symbol(absl::StrCat(lang_name_, "ir")); + return (ir_name + name() + "Attr").ToPascalCase(); + } + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + return "mlir::Value"; + case FIELD_KIND_STMT: + LOG(FATAL) << "Invalid FieldKind: " << kind; + } +} + +// ============================================================================= +// TdType() +// ============================================================================= + +std::string Type::TdType(MaybeNull maybe_null, FieldKind kind) const { + switch (maybe_null) { + case 
MaybeNull::kNo: + return TdType(kind); + + case MaybeNull::kYes: { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + return absl::StrCat("OptionalAttr<", TdType(kind), ">"); + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + return absl::StrCat("Optional<", TdType(kind), ">"); + case FIELD_KIND_STMT: + LOG(FATAL) << "Statement fields are not supported."; + } + } + } +} + +std::string Type::TdType(Optionalness optionalness, FieldKind kind) const { + switch (optionalness) { + case OPTIONALNESS_UNSPECIFIED: + case OPTIONALNESS_REQUIRED: + return TdType(MaybeNull::kNo, kind); + case OPTIONALNESS_MAYBE_NULL: + case OPTIONALNESS_MAYBE_UNDEFINED: + return TdType(MaybeNull::kYes, kind); + } +} + +std::string ListType::TdType(FieldKind kind) const { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: { + auto element_td_type = element_type().TdType(element_maybe_null(), kind); + return absl::StrCat("TypedArrayAttrBase<", element_td_type, ", \"\">"); + } + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: { + // TODO(b/204592400) Variadic<Optional<...>> is not supported, so element_maybe_null() is not applied here. + auto element_td_type = element_type().TdType(kind); + return absl::StrCat("Variadic<", element_td_type, ">"); + } + case FIELD_KIND_STMT: + LOG(FATAL) << "Statement fields are not supported."; + } +} + +std::string VariantType::TdType(FieldKind kind) const { + std::vector type_kinds; + for (const auto &type : types()) { + type_kinds.push_back(type->kind()); + } + + auto VariantAttrTdType = [&] { + std::vector td_types; + for (const auto &type : types()) { + td_types.push_back(type->TdType(kind)); + } + + return absl::StrCat("AnyAttrOf<[", absl::StrJoin(td_types, ", "), "]>"); + }; + + // Variant of builtin types. + if (absl::c_all_of(type_kinds, absl::bind_front(std::equal_to(), + TypeKind::kBuiltin))) { + return VariantAttrTdType(); + } + + // Variant of class types. 
+ if (absl::c_all_of(type_kinds, absl::bind_front(std::equal_to(), + TypeKind::kClass))) { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: + return VariantAttrTdType(); + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + return "AnyType"; + case FIELD_KIND_STMT: + LOG(FATAL) << "Statement fields are not supported."; + } + } + + LOG(FATAL) << "We only support variants of builtin types or variants of " + "class types."; +} + +std::string BuiltinType::TdType(FieldKind kind) const { + CHECK_EQ(kind, FIELD_KIND_ATTR) + << "Invalid FieldKind for builtin type: " << kind; + + switch (builtin_kind()) { + case BuiltinTypeKind::kBool: + return "BoolAttr"; + case BuiltinTypeKind::kInt64: + return "I64Attr"; + case BuiltinTypeKind::kDouble: + return "F64Attr"; + case BuiltinTypeKind::kString: + return "StrAttr"; + } +} + +std::string EnumType::TdType(FieldKind kind) const { + CHECK_EQ(kind, FIELD_KIND_ATTR) + << "Invalid FieldKind for enum type: " << kind; + + // TODO(b/182441574): Properly support enums. 
+ return "StrAttr"; +} + +std::string ClassType::TdType(FieldKind kind) const { + switch (kind) { + case FIELD_KIND_UNSPECIFIED: + LOG(FATAL) << "Unspecified FieldKind."; + case FIELD_KIND_ATTR: { + if (node_def_ != nullptr) { + auto ir_op_name = node_def_->ir_op_name(lang_name_, kind); + if (ir_op_name.has_value()) { + return ir_op_name->ToPascalCase(); + } + } + return (Symbol(lang_name_ + "ir") + name() + "Attr").ToPascalCase(); + } + case FIELD_KIND_LVAL: + case FIELD_KIND_RVAL: + break; + case FIELD_KIND_STMT: + LOG(FATAL) << "Statement fields are not supported."; + } + + return "AnyType"; +} + +} // namespace maldoca diff --git a/maldoca/astgen/type.h b/maldoca/astgen/type.h new file mode 100644 index 0000000..d01ef7d --- /dev/null +++ b/maldoca/astgen/type.h @@ -0,0 +1,577 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MALDOCA_ASTGEN_TYPE_H_ +#define MALDOCA_ASTGEN_TYPE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/symbol.h" +#include "maldoca/astgen/type.pb.h" + +namespace maldoca { + +class NodeDef; + +// The Type Hierarchy +// +// Type ::= NonListType, ListType +// NonListType ::= ScalarType, VariantType +// ScalarType ::= BuiltinType, EnumType, ClassType +// BuiltinType ::= BoolType, Int64Type, DoubleType, StringType +// +// Type +// | +// +-----------+-----------+ +// | | +// NonListType | +// | | +// +----------+----------+ | +// | | | +// ScalarType | | +// | | | +// +------+-------+ | | +// | | | | +// BuiltinType ClassType VariantType ListType + +enum class MaybeNull { + kNo, + kYes, +}; + +enum class TypeKind { + kBuiltin, + kEnum, + kClass, + kVariant, + kList, +}; + +class Type { + public: + virtual ~Type() = default; + + // Check if a Type is a specific type T. + // + // Usage: + // + // const Type &type = ...; + // if (type.IsA<NonListType>()) { + // const auto &non_list_type = static_cast<const NonListType &>(type); + // ... + // } + // + // This is mimicking the LLVM-style RTTI. + // https://llvm.org/docs/ProgrammersManual.html + // + // Each class T that inherits Type needs to have an IsTheClassOf() static + // method that checks if a Type is a T. + template + bool IsA() const { + static_assert(std::is_base_of_v, "T is not a subclass of Type."); + return T::IsTheClassOf(*this); + } + + static bool IsTheClassOf(const Type &type) { return true; } + + TypeKind kind() const { return kind_; } + + absl::string_view lang_name() const { return lang_name_; } + + // Prints TypeScript type annotations. 
+ // + // Types that are maybe_null are printed as variants. + // E.g. "boolean" with "maybe_null = true" ==> "boolean | null". + // E.g. "boolean | string" with "maybe_null = true" ==> "boolean | string | null". + virtual std::string JsType() const = 0; + std::string JsType(MaybeNull maybe_null) const; + + // Prints the C++ type for storing the field. + // + // Types that are maybe_null or maybe_undefined are printed as + // "std::optional<T>". + // + // bool + // => bool + // + // double + // => double + // + // string + // => std::string + // + // ClassType + // => std::unique_ptr<ClassType> + // + // Class1 | Class2 + // => std::variant<std::unique_ptr<Class1>, std::unique_ptr<Class2>> + // + // [ClassType] + // => std::vector<std::unique_ptr<ClassType>> + // + // ClassType with maybe_null or maybe_undefined + // => std::optional<std::unique_ptr<ClassType>> + virtual std::string CcType() const = 0; + std::string CcType(MaybeNull maybe_null) const; + std::string CcType(Optionalness optionalness) const; + + // Prints the C++ return type for the getter function. + // + // Types that are maybe_null or maybe_undefined are printed as + // "std::optional<T>". + // + // bool + // => bool + // + // double + // => double + // + // string + // => absl::string_view + // + // ClassType + // => ClassType* + // + // Class1 | Class2 + // => std::variant<Class1*, Class2*> + // + // [ClassType] + // => std::vector<std::unique_ptr<ClassType>>* + // + // ClassType with maybe_null or maybe_undefined + // => std::optional<ClassType*> + std::string CcMutableGetterType() const; + std::string CcMutableGetterType(MaybeNull maybe_null) const; + std::string CcMutableGetterType(Optionalness optionalness) const; + + // Prints the C++ return type for the const getter function. 
+ // + // bool + // => bool + // + // double + // => double + // + // string + // => absl::string_view + // + // ClassType + // => const ClassType* + // + // Class1 | Class2 + // => std::variant<const Class1*, const Class2*> + // + // [ClassType] + // => const std::vector<std::unique_ptr<ClassType>>* + // + // ClassType with maybe_null or maybe_undefined + // => std::optional<const ClassType*> + std::string CcConstGetterType(MaybeNull maybe_null) const; + std::string CcConstGetterType(Optionalness optionalness) const; + std::string CcConstGetterType() const; + + // Common functions that handle both CcMutableGetterType() and + // CcConstGetterType(). + enum class CcGetterKind { + kMutable, + kConst, + }; + virtual std::string CcGetterType(CcGetterKind getter_kind) const = 0; + std::string CcGetterType(CcGetterKind getter_kind, + MaybeNull maybe_null) const; + std::string CcGetterType(CcGetterKind getter_kind, + Optionalness optionalness) const; + + // Prints the C++ type for MLIR builders. + // + // - maybe_null/optionalness: + // Whether to qualify the type with optional (Type by itself is + // non-optional). + // + // - kind: + // Each field in an AST node has a kind. Different kinds lead to different + // ops in the IR. See detailed explanations in ast_def.proto, and concrete + // listings below. + // + // Builtin type: kind must be FIELD_KIND_ATTR. + // bool => mlir::BoolAttr + // int64 => mlir::IntegerAttr + // double => mlir::FloatAttr + // string => mlir::StringAttr + // + // Builtin type with maybe_null or maybe_undefined: + // Same as above. MLIR attributes can be nullptr. + // + // ClassType with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => mlir::Value + // + // ClassType with kind == FIELD_KIND_ATTR: + // => IrNameClassTypeAttr + // + // ClassType with maybe_null or maybe_undefined + // => Same as above. MLIR values can be nullptr. + // + // Class1 | Class2 with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => mlir::Value + // + // Builtin1 | Builtin2: kind must be FIELD_KIND_ATTR. 
+ // => mlir::Attribute + // + // Builtin1 | Builtin2 with maybe_null or maybe_undefined: + // Same as above. MLIR attributes can be nullptr. + // + // [ClassType] with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => std::vector + // + // [ClassType] with kind == FIELD_KIND_ATTR: + // => mlir::ArrayAttr + // + // [Builtin] + // => mlir::ArrayAttr + virtual std::string CcMlirBuilderType(FieldKind kind) const = 0; + + // Prints the C++ type for MLIR getters. + // + // - maybe_null/optionalness: + // Whether to qualify the type with optional (Type by itself is + // non-optional). + // + // - kind: + // Each field in an AST node has a kind. Different kinds lead to different + // ops in the IR. See detailed explanations in ast_def.proto, and concrete + // listings below. + // + // Builtin type: kind must be FIELD_KIND_ATTR. + // bool => mlir::BoolAttr + // int64 => mlir::IntegerAttr + // double => mlir::FloatAttr + // string => mlir::StringAttr + // + // Builtin type with maybe_null or maybe_undefined: + // Same as above. MLIR attributes can be nullptr. + // + // ClassType with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => mlir::Value + // + // ClassType with kind == FIELD_KIND_ATTR: + // => IrNameClassTypeAttr + // + // ClassType with maybe_null or maybe_undefined + // => Same as above. MLIR values can be nullptr. + // + // Class1 | Class2 with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => mlir::Value + // + // Builtin1 | Builtin2: kind must be FIELD_KIND_ATTR. + // => mlir::Attribute + // + // Builtin1 | Builtin2 with maybe_null or maybe_undefined: + // Same as above. MLIR attributes can be nullptr. + // + // [ClassType] with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => mlir::OperandRange + // + // [ClassType] with kind == FIELD_KIND_ATTR: + // => mlir::ArrayAttr + // + // [Builtin] + // => mlir::ArrayAttr + virtual std::string CcMlirGetterType(FieldKind kind) const = 0; + + // Prints the MLIR TableGen type. 
+ // + // - maybe_null/optionalness: + // Whether to qualify the type with optional (Type by itself is + // non-optional). + // + // - kind: + // Each field in an AST node has a kind. Different kinds lead to different + // ops in the IR. See detailed explanations in ast_def.proto. + // Currently the only difference here is that for an attribute, we use + // OptionalAttr<...>; otherwise, we use Optional<...>. + // + // Builtin type: kind must be FIELD_KIND_ATTR. + // bool => BoolAttr + // int64 => I64Attr + // double => F64Attr + // string => StrAttr + // + // Builtin type with maybe_null or maybe_undefined: + // OptionalAttr<...> + // + // ClassType with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => AnyType + // + // Class1 | Class2 with kind == FIELD_KIND_LVAL or FIELD_KIND_RVAL: + // => AnyType + // + // Builtin1 | Builtin2: kind must be FIELD_KIND_ATTR. + // => AnyAttrOf<[Builtin1Attr, Builtin2Attr]> + // + // Builtin1 | Builtin2 with maybe_null or maybe_undefined: + // => OptionalAttr<AnyAttrOf<[Builtin1Attr, Builtin2Attr]>> + // + // ClassType with maybe_null or maybe_undefined + // => Optional<AnyType> + // + // [ClassType] + // => Variadic<AnyType> + // + // Currently, maybe_null and maybe_undefined are not supported for list types + // and list element types. + std::string TdType(MaybeNull maybe_null, FieldKind kind) const; + std::string TdType(Optionalness optionalness, FieldKind kind) const; + virtual std::string TdType(FieldKind kind) const = 0; + + protected: + explicit Type(TypeKind kind, absl::string_view lang_name) + : lang_name_(lang_name), kind_(kind) {} + std::string lang_name_; + + private: + const TypeKind kind_; + friend class AstDef; +}; + +class NonListType : public Type { + public: + static bool IsTheClassOf(const Type &type) { + return type.kind() == TypeKind::kBuiltin || + type.kind() == TypeKind::kEnum || type.kind() == TypeKind::kClass || + type.kind() == TypeKind::kVariant; + } + + // For `NonListType`, `CcMlirGetterType` and `CcMlirBuilderType` are the same. 
+ // For the definitions of `CcMlirGetterType` and `CcMlirBuilderType`, see + // comments for class `Type`. + virtual std::string CcMlirType(FieldKind kind) const = 0; + + std::string CcMlirBuilderType(FieldKind kind) const final { + return CcMlirType(kind); + } + + std::string CcMlirGetterType(FieldKind kind) const final { + return CcMlirType(kind); + } + + protected: + explicit NonListType(TypeKind kind, absl::string_view lang_name) + : Type(kind, lang_name) {} +}; + +// ListType { +// element_type: NonListType +// element_maybe_null: bool +// } +// +// We explicitly don't allow nested lists, so the element type of a list must be +// non-list. +class ListType : public Type { + public: + explicit ListType(std::unique_ptr element_type, + MaybeNull element_maybe_null, absl::string_view lang_name) + : Type(TypeKind::kList, lang_name), + element_type_(std::move(element_type)), + element_maybe_null_(element_maybe_null) {} + + static bool IsTheClassOf(const Type &type) { + return type.kind() == TypeKind::kList; + } + + const NonListType &element_type() const { return *element_type_; } + NonListType &element_type() { return *element_type_; } + + MaybeNull element_maybe_null() const { return element_maybe_null_; } + + std::string JsType() const override; + + std::string CcType() const override; + + std::string CcGetterType(CcGetterKind getter_kind) const override; + + std::string CcMlirBuilderType(FieldKind kind) const override; + + std::string CcMlirGetterType(FieldKind kind) const override; + + std::string TdType(FieldKind kind) const override; + + private: + std::unique_ptr element_type_; + MaybeNull element_maybe_null_; +}; + +// Scalar type: non-variant and non-list. 
+class ScalarType : public NonListType { + public: + static bool IsTheClassOf(const Type &type) { + return type.kind() == TypeKind::kBuiltin || + type.kind() == TypeKind::kEnum || type.kind() == TypeKind::kClass; + } + + protected: + explicit ScalarType(TypeKind kind, absl::string_view lang_name) + : NonListType(kind, lang_name) {} +}; + +// VariantType { +// types: [ScalarType] +// } +// +// We explicitly limit the types a variant can hold to be scalar. In other +// words, we don't allow nested variants or lists in variants. +class VariantType : public NonListType { + public: + explicit VariantType(std::vector> types, + absl::string_view lang_name) + : NonListType(TypeKind::kVariant, lang_name), types_(std::move(types)) {} + + static bool IsTheClassOf(const Type &type) { + return type.kind() == TypeKind::kVariant; + } + + absl::Span> types() const { return types_; } + + absl::Span> types() { + return absl::MakeSpan(types_); + } + + std::string JsType() const override; + + std::string CcType() const override; + + std::string CcGetterType(CcGetterKind getter_kind) const override; + + std::string CcMlirType(FieldKind kind) const final; + + std::string TdType(FieldKind kind) const override; + + private: + std::vector> types_; +}; + +enum class BuiltinTypeKind { + kBool, + kInt64, + kDouble, + kString, +}; + +class BuiltinType : public ScalarType { + public: + explicit BuiltinType(BuiltinTypeKind builtin_kind, + absl::string_view lang_name) + : ScalarType(TypeKind::kBuiltin, lang_name), + builtin_kind_(builtin_kind) {} + + static bool IsTheClassOf(const Type &type) { + return type.kind() == TypeKind::kBuiltin; + } + + BuiltinTypeKind builtin_kind() const { return builtin_kind_; } + + std::string JsType() const override; + + std::string CcType() const override; + + std::string CcGetterType(CcGetterKind getter_kind) const override; + + std::string CcMlirType(FieldKind kind) const final; + + std::string TdType(FieldKind kind) const override; + + private: + BuiltinTypeKind 
builtin_kind_; +}; + +// Represents an enum type defined elsewhere. +class EnumType : public ScalarType { + public: + explicit EnumType(const Symbol &name, absl::string_view lang_name) + : ScalarType(TypeKind::kEnum, lang_name), name_(name) {} + + static bool IsTheClassOf(const Type &type) { + return type.kind() == TypeKind::kEnum; + } + + const Symbol &name() const { return name_; } + + std::string JsType() const override; + + std::string CcType() const override; + + std::string CcGetterType(CcGetterKind getter_kind) const override; + + std::string CcMlirType(FieldKind kind) const final; + + std::string TdType(FieldKind kind) const override; + + private: + Symbol name_; +}; + +// ClassType { +// name: Symbol +// } +// +// Represents an AST node type defined elsewhere. +class ClassType : public ScalarType { + public: + explicit ClassType(const Symbol &name, absl::string_view lang_name) + : ScalarType(TypeKind::kClass, lang_name), name_(name) {} + + static bool IsTheClassOf(const Type &type) { + return type.kind() == TypeKind::kClass; + } + + const Symbol &name() const { return name_; } + + std::string JsType() const override; + + std::string CcType() const override; + + std::string CcGetterType(CcGetterKind getter_kind) const override; + + std::string CcMlirType(FieldKind kind) const final; + + std::string TdType(FieldKind kind) const override; + + private: + Symbol name_; + const NodeDef *absl_nullable node_def_ = nullptr; + + friend class AstDef; +}; + +// Converts from TypePb to Type. +absl::StatusOr> FromTypePb(const TypePb &pb, + absl::string_view lang_name); + +} // namespace maldoca + +#endif // MALDOCA_ASTGEN_TYPE_H_ diff --git a/maldoca/astgen/type.proto b/maldoca/astgen/type.proto new file mode 100644 index 0000000..c467727 --- /dev/null +++ b/maldoca/astgen/type.proto @@ -0,0 +1,104 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package maldoca; + +option java_multiple_files = true; + + +// The Type Hierarchy +// +// BuiltinType ::= BoolType, DoubleType, StringType +// ScalarType ::= BuiltinType, ClassType +// NonListType ::= ScalarType, VariantType +// Type ::= NonListType, ListType +// +// Type +// | +// +-----------+-----------+ +// | | +// NonListType | +// | | +// +----------+----------+ | +// | | | +// ScalarType | | +// | | | +// +------+-------+ | | +// | | | | +// BuiltinType ClassType VariantType ListType +// +// NOTE: For each non-leaf type, we list all the corresponding leaf types in the +// oneof. This is to make textproto files shorter. + +message TypePb { + oneof kind { + BoolTypePb bool = 1; + Int64TypePb int64 = 2; + DoubleTypePb double = 3; + StringTypePb string = 4; + string enum = 5; + string class = 6; + VariantTypePb variant = 7; + ListTypePb list = 8; + } +} + +message ScalarTypePb { + oneof kind { + BoolTypePb bool = 1; + Int64TypePb int64 = 2; + DoubleTypePb double = 3; + StringTypePb string = 4; + string enum = 5; + string class = 6; + } +} + +message NonListTypePb { + oneof kind { + BoolTypePb bool = 1; + Int64TypePb int64 = 2; + DoubleTypePb double = 3; + StringTypePb string = 4; + string enum = 5; + string class = 6; + VariantTypePb variant = 7; + } +} + +// Builtin types. +// +// NOTE: We choose to define individual empty messages instead of a enum, for +// simpler textprotos. +// +// In particular, we would be able to write: +// "type { bool {} }" +// instead of: +// "type { builtin: BUILTIN_BOOL }". 
+ +message BoolTypePb {} +message Int64TypePb {} +message DoubleTypePb {} +message StringTypePb {} + +message VariantTypePb { + repeated ScalarTypePb types = 1; +} + +message ListTypePb { + optional NonListTypePb element_type = 1; + optional bool element_maybe_null = 2; +} diff --git a/maldoca/astgen/type_test.cc b/maldoca/astgen/type_test.cc new file mode 100644 index 0000000..6d5bcd4 --- /dev/null +++ b/maldoca/astgen/type_test.cc @@ -0,0 +1,490 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "maldoca/astgen/type.h" + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "maldoca/astgen/ast_def.pb.h" +#include "maldoca/astgen/type.pb.h" +#include "maldoca/base/testing/status_matchers.h" +#include "maldoca/base/filesystem.h" + +namespace maldoca { +namespace { + +struct TypeTestCase { + const char *type_pb; + const char *js_type; + const char *cc_type; + const char *cc_getter_type; + const char *cc_const_getter_type; + const char *cc_lang_name; + absl::flat_hash_map td_types; + absl::flat_hash_map cc_mlir_builder_type; + absl::flat_hash_map cc_mlir_getter_type; +}; + +void TestTypePbToTypeAndPrint(TypeTestCase type_test_case) { + TypePb type_pb; + MALDOCA_ASSERT_OK(ParseTextProto(type_test_case.type_pb, + "type_test_case.type_pb", &type_pb)); + MALDOCA_ASSERT_OK_AND_ASSIGN( + std::unique_ptr type, + FromTypePb(type_pb, type_test_case.cc_lang_name != nullptr + ? type_test_case.cc_lang_name + : "")); + EXPECT_EQ(type->JsType(), type_test_case.js_type); + EXPECT_EQ(type->CcType(), type_test_case.cc_type); + EXPECT_EQ(type->CcMutableGetterType(), type_test_case.cc_getter_type); + EXPECT_EQ(type->CcConstGetterType(), type_test_case.cc_const_getter_type); + for (const auto &pair : type_test_case.td_types) { + // We don't allow C++17 in the codebase for compatibility reasons, so we + // cannot use structured binding. + FieldKind field_kind; + std::string td_type; + std::tie(field_kind, td_type) = pair; + + EXPECT_EQ(type->TdType(field_kind), td_type); + } + for (const auto &pair : type_test_case.cc_mlir_builder_type) { + // We don't allow C++17 in the codebase for compatibility reasons, so we + // cannot use structured binding. 
+ FieldKind field_kind; + std::string cc_mlir_builder_type; + std::tie(field_kind, cc_mlir_builder_type) = pair; + + EXPECT_EQ(type->CcMlirBuilderType(field_kind), cc_mlir_builder_type); + } + for (const auto &pair : type_test_case.cc_mlir_getter_type) { + // We don't allow C++17 in the codebase for compatibility reasons, so we + // cannot use structured binding. + // FieldKind field_kind; + // std::string cc_mlir_getter_type; + // std::tie(field_kind, cc_mlir_getter_type) = pair; + auto [field_kind, cc_mlir_getter_type] = pair; + + EXPECT_EQ(type->CcMlirGetterType(field_kind), cc_mlir_getter_type); + } +} + +TEST(TypeTest, ConvertBuiltinType) { + TestTypePbToTypeAndPrint({ + .type_pb = "bool {}", + .js_type = "boolean", + .cc_type = "bool", + .cc_getter_type = "bool", + .cc_const_getter_type = "bool", + .td_types = + { + {FIELD_KIND_ATTR, "BoolAttr"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_ATTR, "mlir::BoolAttr"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_ATTR, "mlir::BoolAttr"}, + }, + }); + + TestTypePbToTypeAndPrint({ + .type_pb = "int64 {}", + .js_type = "/*int64*/number", + .cc_type = "int64_t", + .cc_getter_type = "int64_t", + .cc_const_getter_type = "int64_t", + .td_types = + { + {FIELD_KIND_ATTR, "I64Attr"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_ATTR, "mlir::IntegerAttr"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_ATTR, "mlir::IntegerAttr"}, + }, + }); + + TestTypePbToTypeAndPrint({ + .type_pb = "double {}", + .js_type = "/*double*/number", + .cc_type = "double", + .cc_getter_type = "double", + .cc_const_getter_type = "double", + .td_types = + { + {FIELD_KIND_ATTR, "F64Attr"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_ATTR, "mlir::FloatAttr"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_ATTR, "mlir::FloatAttr"}, + }, + }); + + TestTypePbToTypeAndPrint({ + .type_pb = "string {}", + .js_type = "string", + .cc_type = "std::string", + .cc_getter_type = "absl::string_view", + .cc_const_getter_type = 
"absl::string_view", + .td_types = + { + {FIELD_KIND_ATTR, "StrAttr"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_ATTR, "mlir::StringAttr"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_ATTR, "mlir::StringAttr"}, + }, + }); +} + +TEST(TypeTest, ConvertEnumType) { + TestTypePbToTypeAndPrint({ + .type_pb = R"pb(enum: "BinaryOperator")pb", + .js_type = "BinaryOperator", + .cc_type = "TestLangNameBinaryOperator", + .cc_getter_type = "TestLangNameBinaryOperator", + .cc_const_getter_type = "TestLangNameBinaryOperator", + .cc_lang_name = "TestLangName", + .td_types = + { + {FIELD_KIND_ATTR, "StrAttr"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_ATTR, "mlir::StringAttr"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_ATTR, "mlir::StringAttr"}, + }, + }); +} + +TEST(TypeTest, ConvertClassType) { + TestTypePbToTypeAndPrint({ + .type_pb = R"pb(class: "BinaryExpression")pb", + .js_type = "BinaryExpression", + .cc_type = "std::unique_ptr", + .cc_getter_type = "TestLangNameBinaryExpression*", + .cc_const_getter_type = "const TestLangNameBinaryExpression*", + .cc_lang_name = "TestLangName", + .td_types = + { + {FIELD_KIND_RVAL, "AnyType"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_RVAL, "mlir::Value"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_RVAL, "mlir::Value"}, + }, + }); +} + +TEST(TypeTest, ConvertVariantType) { + TestTypePbToTypeAndPrint({ + .type_pb = R"pb( + variant { + types { bool {} } + types { string {} } + } + )pb", + .js_type = "boolean | string", + .cc_type = "std::variant", + .cc_getter_type = "std::variant", + .cc_const_getter_type = "std::variant", + .td_types = + { + {FIELD_KIND_ATTR, "AnyAttrOf<[BoolAttr, StrAttr]>"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_ATTR, "mlir::Attribute"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_ATTR, "mlir::Attribute"}, + }, + }); + + TestTypePbToTypeAndPrint({ + .type_pb = R"pb( + variant { + types { class: "Expression" } + types { class: "Pattern" } + } + )pb", + .js_type = 
"Expression | Pattern", + .cc_type = "std::variant, " + "std::unique_ptr>", + .cc_getter_type = + "std::variant", + .cc_const_getter_type = "std::variant", + .cc_lang_name = "TestLangName", + .td_types = + { + {FIELD_KIND_RVAL, "AnyType"}, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_RVAL, "mlir::Value"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_RVAL, "mlir::Value"}, + }, + }); +} + +TEST(TypeTest, ConvertListType) { + TestTypePbToTypeAndPrint({ + .type_pb = R"pb( + list { element_type { class: "Expression" } } + )pb", + .js_type = "[ Expression ]", + .cc_type = "std::vector>", + .cc_getter_type = "std::vector>*", + .cc_const_getter_type = + "const std::vector>*", + .cc_lang_name = "TestLangName", + .td_types = + { + { + {FIELD_KIND_RVAL, "Variadic"}, + }, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_RVAL, "std::vector"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_RVAL, "mlir::OperandRange"}, + }, + }); + + TestTypePbToTypeAndPrint({ + .type_pb = R"pb( + list { + element_type { class: "Expression" } + element_maybe_null: true + } + )pb", + .js_type = "[ Expression | null ]", + .cc_type = "std::vector>>", + .cc_getter_type = "std::vector>>*", + .cc_const_getter_type = "const " + "std::vector>>*", + .cc_lang_name = "TestLangName", + .td_types = + { + { + {FIELD_KIND_RVAL, "Variadic"}, + }, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_RVAL, "std::vector"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_RVAL, "mlir::OperandRange"}, + }, + }); + + TestTypePbToTypeAndPrint({ + .type_pb = R"pb( + list { + element_type { + variant { + types { class: "Expression" } + types { class: "Pattern" } + } + } + element_maybe_null: true + } + )pb", + .js_type = "[ Expression | Pattern | null ]", + .cc_type = + "std::vector, std::unique_ptr>>>", + .cc_getter_type = "std::vector, " + "std::unique_ptr>>>*", + .cc_const_getter_type = "const " + "std::vector" + ", std::unique_ptr>>>*", + .cc_lang_name = "TestLangName", + .td_types = + { + { + {FIELD_KIND_RVAL, 
"Variadic"}, + }, + }, + .cc_mlir_builder_type = + { + {FIELD_KIND_RVAL, "std::vector"}, + }, + .cc_mlir_getter_type = + { + {FIELD_KIND_RVAL, "mlir::OperandRange"}, + }, + }); +} + +TEST(TypeTest, IsABuiltinType) { + TypePb type_pb; + MALDOCA_ASSERT_OK(maldoca::ParseTextProto(R"pb( + bool {} + )pb", "TypePb for IsABuiltinType", &type_pb)); + + MALDOCA_ASSERT_OK_AND_ASSIGN(std::unique_ptr type, + FromTypePb(type_pb, "TestLangName")); + + EXPECT_TRUE(type->IsA()); + + EXPECT_TRUE(type->IsA()); + EXPECT_TRUE(type->IsA()); + EXPECT_TRUE(type->IsA()); + + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); +} + +TEST(TypeTest, IsAClassType) { + TypePb type_pb; + MALDOCA_ASSERT_OK(maldoca::ParseTextProto(R"pb( + class: "Expression" + )pb", "TypePb for IsAClassType", &type_pb)); + + MALDOCA_ASSERT_OK_AND_ASSIGN(std::unique_ptr type, + FromTypePb(type_pb, "TestLangName")); + + EXPECT_TRUE(type->IsA()); + + EXPECT_TRUE(type->IsA()); + EXPECT_TRUE(type->IsA()); + EXPECT_TRUE(type->IsA()); + + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); +} + +TEST(TypeTest, IsAVariantType) { + TypePb type_pb; + MALDOCA_ASSERT_OK(maldoca::ParseTextProto(R"pb( + variant { + types { bool {} } + types { string {} } + } + )pb", + "TypePb for IsAVariantType", + &type_pb)); + + MALDOCA_ASSERT_OK_AND_ASSIGN(std::unique_ptr type, + FromTypePb(type_pb, "TestLangName")); + + EXPECT_TRUE(type->IsA()); + + EXPECT_TRUE(type->IsA()); + EXPECT_TRUE(type->IsA()); + + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); +} + +TEST(TypeTest, IsAListType) { + TypePb type_pb; + MALDOCA_ASSERT_OK(maldoca::ParseTextProto(R"pb( + list { element_type { bool {} } } + )pb", + "TypePb for IsAListType", + &type_pb)); + + MALDOCA_ASSERT_OK_AND_ASSIGN(std::unique_ptr type, + FromTypePb(type_pb, "TestLangName")); + + EXPECT_TRUE(type->IsA()); + + EXPECT_TRUE(type->IsA()); + + 
EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); + EXPECT_FALSE(type->IsA()); +} + +TEST(TypeTest, EmptyTypeIsInvalid) { + TypePb type_pb; + MALDOCA_ASSERT_OK(maldoca::ParseTextProto("", "empty TypePb", &type_pb)); + EXPECT_THAT(FromTypePb(type_pb, "TestLangName"), + maldoca::testing::StatusIs(absl::StatusCode::kInvalidArgument, + "Invalid TypePb: KIND_NOT_SET.")); + + MALDOCA_ASSERT_OK(maldoca::ParseTextProto(R"pb( + variant {} + )pb", "empty variant", &type_pb)); + EXPECT_THAT(FromTypePb(type_pb, "TestLangName"), + maldoca::testing::StatusIs(absl::StatusCode::kInvalidArgument, + "Empty variant type.")); + + MALDOCA_ASSERT_OK(maldoca::ParseTextProto(R"pb( + variant { types {} } + )pb", "variant with empty type", &type_pb)); + EXPECT_THAT(FromTypePb(type_pb, "TestLangName"), + maldoca::testing::StatusIs( + absl::StatusCode::kInvalidArgument, + "Invalid variant element type: KIND_NOT_SET.")); + + MALDOCA_ASSERT_OK(maldoca::ParseTextProto(R"pb( + list { element_type {} } + )pb", "list with empty element type", &type_pb)); + EXPECT_THAT( + FromTypePb(type_pb, "TestLangName"), + maldoca::testing::StatusIs(absl::StatusCode::kInvalidArgument, + "Invalid list element type: KIND_NOT_SET.")); +} + +} // namespace +} // namespace maldoca diff --git a/maldoca/js/ast/ast_from_json.generated.cc b/maldoca/js/ast/ast_from_json.generated.cc index 8350e86..effea1a 100644 --- a/maldoca/js/ast/ast_from_json.generated.cc +++ b/maldoca/js/ast/ast_from_json.generated.cc @@ -28,15 +28,15 @@ #include #include +#include "maldoca/js/ast/ast.generated.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "nlohmann/json.hpp" #include "maldoca/base/status_macros.h" -#include "maldoca/js/ast/ast.generated.h" +#include "nlohmann/json.hpp" namespace maldoca { diff 
--git a/maldoca/js/ast/ast_to_json.generated.cc b/maldoca/js/ast/ast_to_json.generated.cc index 28a08fd..f32ddfd 100644 --- a/maldoca/js/ast/ast_to_json.generated.cc +++ b/maldoca/js/ast/ast_to_json.generated.cc @@ -26,13 +26,13 @@ #include #include +#include "maldoca/js/ast/ast.generated.h" #include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "nlohmann/json.hpp" #include "maldoca/base/status_macros.h" -#include "maldoca/js/ast/ast.generated.h" namespace maldoca { diff --git a/maldoca/js/driver/driver.proto b/maldoca/js/driver/driver.proto index 76d208e..072c2f4 100644 --- a/maldoca/js/driver/driver.proto +++ b/maldoca/js/driver/driver.proto @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/maldoca/js/ir/BUILD b/maldoca/js/ir/BUILD index ea23ebf..dbc2321 100644 --- a/maldoca/js/ir/BUILD +++ b/maldoca/js/ir/BUILD @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -363,12 +363,7 @@ cc_library( "//maldoca/js/babel:babel_cc_proto", "//maldoca/js/driver", "//maldoca/js/driver:driver_cc_proto", - "//maldoca/js/ir/analyses:analysis", - "//maldoca/js/ir/analyses:conditional_forward_dataflow_analysis", - "//maldoca/js/ir/analyses:dataflow_analysis", - "//maldoca/js/ir/analyses/constant_propagation:analysis", "//maldoca/js/ir/conversion:utils", - "//maldoca/js/ir/transforms:transform", "//maldoca/js/quickjs_babel", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log", diff --git a/maldoca/js/ir/analyses/BUILD b/maldoca/js/ir/analyses/BUILD index afd3d8f..d49decc 100644 --- a/maldoca/js/ir/analyses/BUILD +++ b/maldoca/js/ir/analyses/BUILD @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/maldoca/js/ir/analyses/analysis.cc b/maldoca/js/ir/analyses/analysis.cc index 20f585f..bb8596d 100644 --- a/maldoca/js/ir/analyses/analysis.cc +++ b/maldoca/js/ir/analyses/analysis.cc @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/maldoca/js/ir/analyses/conditional_forward_dataflow_analysis.h b/maldoca/js/ir/analyses/conditional_forward_dataflow_analysis.h index 0afc998..0729cda 100644 --- a/maldoca/js/ir/analyses/conditional_forward_dataflow_analysis.h +++ b/maldoca/js/ir/analyses/conditional_forward_dataflow_analysis.h @@ -41,7 +41,6 @@ class JsirConditionalForwardDataFlowAnalysis : public JsirForwardDataFlowAnalysis { public: using Base = JsirForwardDataFlowAnalysis; - using DenseBase = JsirDenseForwardDataFlowAnalysis; explicit JsirConditionalForwardDataFlowAnalysis(mlir::DataFlowSolver &solver) : JsirForwardDataFlowAnalysis(solver) {} @@ -88,10 +87,10 @@ JsirConditionalForwardDataFlowAnalysis::GetIsExecutable( template void JsirConditionalForwardDataFlowAnalysis::VisitOp( mlir::Operation *op) { - JsirStateRef before_state_ref = DenseBase::GetStateBefore(op); + JsirStateRef before_state_ref = Base::GetStateBefore(op); const StateT *before = &before_state_ref.value(); - JsirStateRef after_state_ref = DenseBase::GetStateAfter(op); + JsirStateRef after_state_ref = Base::GetStateAfter(op); auto [operands, result_state_refs] = Base::GetValueStateRefs(op); @@ -120,6 +119,8 @@ void JsirConditionalForwardDataFlowAnalysis::VisitOp( llvm::isa(op) || llvm::isa(op) || llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op) || llvm::isa(op)) { return; } diff --git a/maldoca/js/ir/analyses/constant_propagation/analysis.cc b/maldoca/js/ir/analyses/constant_propagation/analysis.cc index 8bb0607..948891d 100644 --- a/maldoca/js/ir/analyses/constant_propagation/analysis.cc +++ b/maldoca/js/ir/analyses/constant_propagation/analysis.cc @@ -270,30 +270,32 @@ bool JsirConstantPropagationAnalysis::IsCfgEdgeExecutable( return true; } - auto [liveness_source, liveness_kind] = edge->getLivenessInfo().value(); - auto liveness_source_state_ref = GetStateAt(liveness_source); - if (liveness_source_state_ref.value().IsUninitialized()) { + auto [lhs_value, rhs, liveness_kind] = 
edge->getLivenessInfo().value(); + auto lhs_state_ref = GetStateAt(lhs_value); + if (lhs_state_ref.value().IsUninitialized()) { return false; - } else if (liveness_source_state_ref.value().IsUnknown()) { + } else if (lhs_state_ref.value().IsUnknown()) { return true; } - mlir::Attribute true_attr = - mlir::BoolAttr::get(liveness_source.getContext(), true); - mlir::Attribute false_attr = - mlir::BoolAttr::get(liveness_source.getContext(), false); - mlir::Attribute null_attr = - JsirNullLiteralAttr::get(liveness_source.getContext()); + if (auto rhs_attr = llvm::dyn_cast(rhs); + rhs_attr != nullptr) { + switch (liveness_kind) { + case LivenessKind::kLiveIfEqualOrUnknown: + return *lhs_state_ref.value() == rhs_attr; + case LivenessKind::kLiveIfNotEqualOrUnknown: + return *lhs_state_ref.value() != rhs_attr; + } + } + auto rhs_value = llvm::dyn_cast(rhs); + auto rhs_state_ref = GetStateAt(rhs_value); + std::optional rhs_attr = *rhs_state_ref.value(); switch (liveness_kind) { - case LivenessKind::kLiveIfTrueOrUnknown: - return *liveness_source_state_ref.value() == true_attr; - case LivenessKind::kLiveIfFalseOrUnknown: - return *liveness_source_state_ref.value() == false_attr; - case LivenessKind::kLiveIfNullOrUnknown: - return *liveness_source_state_ref.value() == null_attr; - case LivenessKind::kLiveIfNonNullOrUnknown: - return *liveness_source_state_ref.value() != null_attr; + case LivenessKind::kLiveIfEqualOrUnknown: + return *lhs_state_ref.value() == rhs_attr; + case LivenessKind::kLiveIfNotEqualOrUnknown: + return *lhs_state_ref.value() != rhs_attr; } } diff --git a/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/README.generated.md b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/README.generated.md new file mode 100644 index 0000000..b455aed --- /dev/null +++ b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/README.generated.md @@ -0,0 +1,8 @@ +To run manually: + +```shell +bazel run //maldoca/js/ir:jsir_gen -- \ + 
--input_file $(pwd)/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/input.js \ + --passes "source2ast,ast2hir" \ + --jsir_analysis constant_propagation +``` diff --git a/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/input.js b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/input.js new file mode 100644 index 0000000..356a349 --- /dev/null +++ b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/input.js @@ -0,0 +1,13 @@ +var x = 2; +switch (x) { + case 1: + break; + case 2: + case 3: + x = 4; + break; + case 4: + x = 5; + break; +}; +console.log(x); diff --git a/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/output.generated.txt b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/output.generated.txt new file mode 100644 index 0000000..b277b03 --- /dev/null +++ b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/output.generated.txt @@ -0,0 +1,199 @@ +// JSHIR: "jsir.file"() <{comments = []}> ({ +// JSHIR-NEXT: "jsir.program"() <{source_type = "script"}> ({ +// JSHIR-NEXT: "jsir.variable_declaration"() <{kind = "var"}> ({ +// JSHIR-NEXT: %5 = "jsir.identifier_ref"() <{name = "x"}> : () -> !jsir.any +// JSHIR-NEXT: %6 = "jsir.numeric_literal"() <{extra = #jsir, value = 2.000000e+00 : f64}> : () -> !jsir.any +// JSHIR-NEXT: %7 = "jsir.variable_declarator"(%5, %6) : (!jsir.any, !jsir.any) -> !jsir.any +// JSHIR-NEXT: "jsir.exprs_region_end"(%7) : (!jsir.any) -> () +// JSHIR-NEXT: }) : () -> () +// JSHIR-NEXT: %0 = "jsir.identifier"() <{name = "x"}> : () -> !jsir.any +// JSHIR-NEXT: "jshir.switch_statement"(%0) ({ +// JSHIR-NEXT: "jshir.switch_case"() ({ +// JSHIR-NEXT: %5 = "jsir.numeric_literal"() <{extra = #jsir, value = 1.000000e+00 : f64}> : () -> !jsir.any +// JSHIR-NEXT: "jsir.expr_region_end"(%5) : (!jsir.any) -> () +// JSHIR-NEXT: }, { +// JSHIR-NEXT: "jshir.break_statement"() : () -> () +// JSHIR-NEXT: }) : () -> () +// JSHIR-NEXT: "jshir.switch_case"() ({ 
+// JSHIR-NEXT: %5 = "jsir.numeric_literal"() <{extra = #jsir, value = 2.000000e+00 : f64}> : () -> !jsir.any +// JSHIR-NEXT: "jsir.expr_region_end"(%5) : (!jsir.any) -> () +// JSHIR-NEXT: }, { +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: }) : () -> () +// JSHIR-NEXT: "jshir.switch_case"() ({ +// JSHIR-NEXT: %5 = "jsir.numeric_literal"() <{extra = #jsir, value = 3.000000e+00 : f64}> : () -> !jsir.any +// JSHIR-NEXT: "jsir.expr_region_end"(%5) : (!jsir.any) -> () +// JSHIR-NEXT: }, { +// JSHIR-NEXT: %5 = "jsir.identifier_ref"() <{name = "x"}> : () -> !jsir.any +// JSHIR-NEXT: %6 = "jsir.numeric_literal"() <{extra = #jsir, value = 4.000000e+00 : f64}> : () -> !jsir.any +// JSHIR-NEXT: %7 = "jsir.assignment_expression"(%5, %6) <{operator_ = "="}> : (!jsir.any, !jsir.any) -> !jsir.any +// JSHIR-NEXT: "jsir.expression_statement"(%7) : (!jsir.any) -> () +// JSHIR-NEXT: "jshir.break_statement"() : () -> () +// JSHIR-NEXT: }) : () -> () +// JSHIR-NEXT: "jshir.switch_case"() ({ +// JSHIR-NEXT: %5 = "jsir.numeric_literal"() <{extra = #jsir, value = 4.000000e+00 : f64}> : () -> !jsir.any +// JSHIR-NEXT: "jsir.expr_region_end"(%5) : (!jsir.any) -> () +// JSHIR-NEXT: }, { +// JSHIR-NEXT: %5 = "jsir.identifier_ref"() <{name = "x"}> : () -> !jsir.any +// JSHIR-NEXT: %6 = "jsir.numeric_literal"() <{extra = #jsir, value = 5.000000e+00 : f64}> : () -> !jsir.any +// JSHIR-NEXT: %7 = "jsir.assignment_expression"(%5, %6) <{operator_ = "="}> : (!jsir.any, !jsir.any) -> !jsir.any +// JSHIR-NEXT: "jsir.expression_statement"(%7) : (!jsir.any) -> () +// JSHIR-NEXT: "jshir.break_statement"() : () -> () +// JSHIR-NEXT: }) : () -> () +// JSHIR-NEXT: }) : (!jsir.any) -> () +// JSHIR-NEXT: "jsir.empty_statement"() : () -> () +// JSHIR-NEXT: %1 = "jsir.identifier"() <{name = "console"}> : () -> !jsir.any +// JSHIR-NEXT: %2 = "jsir.member_expression"(%1) <{literal_property = #jsir, , "log", 130, 133, 0, "log">}> : (!jsir.any) -> !jsir.any +// JSHIR-NEXT: %3 = "jsir.identifier"() <{name = "x"}> : () -> 
!jsir.any +// JSHIR-NEXT: %4 = "jsir.call_expression"(%2, %3) : (!jsir.any, !jsir.any) -> !jsir.any +// JSHIR-NEXT: "jsir.expression_statement"(%4) : (!jsir.any) -> () +// JSHIR-NEXT: }, { +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: }) : () -> () +// JSHIR-NEXT: }) : () -> () +// JSHIR-NEXT: jsir.file {[]} ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.program {"script"} ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.variable_declaration {"var"} ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %5 = jsir.identifier_ref {"x"} +// JSHIR-NEXT: // %5 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %6 = jsir.numeric_literal {#jsir, 2.000000e+00 : f64} +// JSHIR-NEXT: // %6 = 2.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %7 = jsir.variable_declarator (%5, %6) +// JSHIR-NEXT: // %7 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.exprs_region_end (%7) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %0 = jsir.identifier {"x"} +// JSHIR-NEXT: // %0 = 2.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.switch_statement (%0) ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.switch_case ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %5 = jsir.numeric_literal {#jsir, 1.000000e+00 : f64} +// JSHIR-NEXT: // %5 = 1.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.expr_region_end (%5) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }, { +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.break_statement +// JSHIR-NEXT: // State [default = ] { } +// 
JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.switch_case ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %5 = jsir.numeric_literal {#jsir, 2.000000e+00 : f64} +// JSHIR-NEXT: // %5 = 2.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.expr_region_end (%5) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }, { +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.switch_case ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %5 = jsir.numeric_literal {#jsir, 3.000000e+00 : f64} +// JSHIR-NEXT: // %5 = 3.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.expr_region_end (%5) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }, { +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %5 = jsir.identifier_ref {"x"} +// JSHIR-NEXT: // %5 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %6 = jsir.numeric_literal {#jsir, 4.000000e+00 : f64} +// JSHIR-NEXT: // %6 = 4.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %7 = jsir.assignment_expression (%5, %6) {"="} +// JSHIR-NEXT: // %7 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.expression_statement (%7) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.break_statement +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.switch_case ({ +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %5 = jsir.numeric_literal {#jsir, 4.000000e+00 : f64} +// JSHIR-NEXT: // %5 = 4.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.expr_region_end (%5) +// 
JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }, { +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %5 = jsir.identifier_ref {"x"} +// JSHIR-NEXT: // %5 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %6 = jsir.numeric_literal {#jsir, 5.000000e+00 : f64} +// JSHIR-NEXT: // %6 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %7 = jsir.assignment_expression (%5, %6) {"="} +// JSHIR-NEXT: // %7 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.expression_statement (%7) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jshir.break_statement +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.empty_statement +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %1 = jsir.identifier {"console"} +// JSHIR-NEXT: // %1 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %2 = jsir.member_expression (%1) {#jsir, , "log", 130, 133, 0, "log">} +// JSHIR-NEXT: // %2 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %3 = jsir.identifier {"x"} +// JSHIR-NEXT: // %3 = 4.000000e+00 : f64 +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: %4 = jsir.call_expression (%2, %3) +// JSHIR-NEXT: // %4 = +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: jsir.expression_statement (%4) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }, { +// JSHIR-NEXT: ^bb0: +// JSHIR-NEXT: // +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } +// JSHIR-NEXT: }) +// JSHIR-NEXT: // State [default = ] { } diff --git a/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/run.generated.lit b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/run.generated.lit new file mode 100644 index 0000000..f120aee --- /dev/null +++ 
b/maldoca/js/ir/analyses/constant_propagation/tests/switch_jshir/run.generated.lit @@ -0,0 +1,5 @@ +// RUN: CURRENT_FILE_BASENAME=$(basename %s .lit) && \ +// RUN: jsir_gen --input_file "$(dirname %s)"/input.js \ +// RUN: --passes "source2ast,ast2hir" \ +// RUN: --jsir_analysis constant_propagation \ +// RUN: | FileCheck --check-prefix JSHIR "$(dirname %s)"/output.generated.txt diff --git a/maldoca/js/ir/analyses/dataflow_analysis.h b/maldoca/js/ir/analyses/dataflow_analysis.h index c7b1e7d..d5b5d71 100644 --- a/maldoca/js/ir/analyses/dataflow_analysis.h +++ b/maldoca/js/ir/analyses/dataflow_analysis.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "llvm/ADT/ArrayRef.h" @@ -29,11 +30,14 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/AsmState.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -83,12 +87,11 @@ class JsirStateElement : public mlir::AnalysisState { } // namespace detail -enum class LivenessKind { - kLiveIfTrueOrUnknown, - kLiveIfFalseOrUnknown, - kLiveIfNullOrUnknown, - kLiveIfNonNullOrUnknown -}; +enum class LivenessKind { kLiveIfEqualOrUnknown, kLiveIfNotEqualOrUnknown }; + +using LivenessInfo = + std::tuple, + LivenessKind>; class JsirGeneralCfgEdge : public mlir::GenericLatticeAnchorBase< @@ -96,7 +99,7 @@ class JsirGeneralCfgEdge std::tuple, mlir::SmallVector, - std::optional>>> { + std::optional>> { public: using Base::Base; @@ -112,7 +115,7 @@ class JsirGeneralCfgEdge return std::get<3>(getValue()); } - std::optional> getLivenessInfo() const { + std::optional getLivenessInfo() const { return std::get<4>(getValue()); } @@ 
-128,18 +131,12 @@ class JsirGeneralCfgEdge os << getSuccValues().size(); if (getLivenessInfo().has_value()) { os << "\n liveness kind: "; - switch (std::get<1>(getLivenessInfo().value())) { - case LivenessKind::kLiveIfTrueOrUnknown: - os << "LiveIfTrueOrUnknown"; - break; - case LivenessKind::kLiveIfFalseOrUnknown: - os << "LiveIfFalseOrUnknown"; - break; - case LivenessKind::kLiveIfNullOrUnknown: - os << "LiveIfNullOrUnknown"; + switch (std::get<2>(getLivenessInfo().value())) { + case LivenessKind::kLiveIfEqualOrUnknown: + os << "LiveIfEqualOrUnknown"; break; - case LivenessKind::kLiveIfNonNullOrUnknown: - os << "LiveIfNonNullOrUnknown"; + case LivenessKind::kLiveIfNotEqualOrUnknown: + os << "LiveIfNotEqualOrUnknown"; break; } } @@ -280,16 +277,17 @@ class JsirDataFlowAnalysisPrinter { enum class DataflowDirection { kForward, kBackward }; // ============================================================================= -// JsirDenseDataFlowAnalysis +// JsirDataFlowAnalysis // ============================================================================= -// A dataflow analysis API that attaches lattices to operations. This analysis -// supports both forward and backward analysis. -template -class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, - public JsirDataFlowAnalysisPrinter, - public JsirDenseStates> { +// A dataflow analysis API that attaches lattices to both values and operations. +// This analysis supports both forward and backward analysis. 
+template +class JsirDataFlowAnalysis : public mlir::DataFlowAnalysis, + public JsirDataFlowAnalysisPrinter, + public JsirDenseStates>, + public JsirSparseStates> { public: - explicit JsirDenseDataFlowAnalysis(mlir::DataFlowSolver &solver) + explicit JsirDataFlowAnalysis(mlir::DataFlowSolver &solver) : mlir::DataFlowAnalysis(solver), solver_(solver) { registerAnchorKind(); } @@ -297,28 +295,25 @@ class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, // Set the initial state of an entry block for forward analysis or exit block // for backward analysis. virtual void InitializeBoundaryBlock(mlir::Block *block, - JsirStateRef boundary_state) = 0; + JsirStateRef boundary_state) { + std::vector> arg_states; + for (mlir::Value arg : block->getArguments()) { + arg_states.push_back(GetStateAt(arg)); + } + return InitializeBoundaryBlock(block, boundary_state, arg_states); + } - // This virtual method is the transfer function for an operation. It is called - // by its overloaded protected method. Remember, what the input and output - // should come from is different in forward and backward analysis. In forward - // analysis, the input should be the state before the op. In backward - // analysis, it should be the state after the op. - // +--------+-------------------+-------------------+ - // | | Forward Analysis | Backward Analysis | - // +--------+-------------------+-------------------+ - // | Input | Before | After | - // +--------+-------------------+-------------------+ - // | Output | After | Before | - // +--------+-------------------+-------------------+ - // - // Input: - // - The analysis state (lattice value) input for the op - // - // Output: - // - The analysis state (lattice value) output for the op - virtual void VisitOp(mlir::Operation *op, const StateT *input, - JsirStateRef output) = 0; + // The initial state on a boundary `mlir::Value`, e.g. a parameter of an entry + // block. 
This is used in both backward and forward analysis, when visiting + // the CFG edges. + virtual ValueT BoundaryInitialValue() const = 0; + + // Sets the initial state on a boundary `mlir::Block`, i.e. the entry state of + // an entry block for a forward analysis, or the exit state of an exit block + // for a backward analysis. + virtual void InitializeBoundaryBlock( + mlir::Block *block, JsirStateRef boundary_state, + llvm::MutableArrayRef> arg_states) = 0; // Gets the state attached before an op. JsirStateRef GetStateBefore(mlir::Operation *op) final; @@ -332,6 +327,38 @@ class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, // Gets the state attached at the end of a block. JsirStateRef GetStateAtEndOf(mlir::Block *block) final; + // This virtual method is the transfer function for an operation. It is called + // by its overloaded protected method. Same as its version in dense analysis, + // what the input and output of sparse states (`ValueT`) should come from + // is different in forward and backward analysis. + // + // Generally, a transfer function in a dataflow analysis can be represented in + // a form of + // + // output = gen ∪ (input - kill) + // + // where ∪ is the lattice join operation. Usually, for forward analysis, the + // `gen` set comes from the `results` in a JSIR `Operation`, and `kill` set + // comes from the `operands`. For backward analysis, it is the opposite case. + // + // For sparse values, we would update the values in `gen` set, and read values + // from `kill` set. 
Thus, we have the following table for sparse values: + // +--------+-------------------+-------------------+ + // | | Forward Analysis | Backward Analysis | + // +--------+-------------------+-------------------+ + // | Input | Operands | Results | + // +--------+-------------------+-------------------+ + // | Output | Results | Operands | + // +--------+-------------------+-------------------+ + virtual void VisitOp( + mlir::Operation *op, llvm::ArrayRef sparse_input, + const StateT *dense_input, + llvm::MutableArrayRef> sparse_output, + JsirStateRef dense_output) = 0; + + // Gets the state at an SSA value. + JsirStateRef GetStateAt(mlir::Value value) final; + // Format: // // ^block_name: @@ -346,6 +373,12 @@ class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, void PrintRegion(mlir::Region ®ion, size_t num_indents, mlir::AsmState &asm_state, llvm::raw_ostream &os); + // Callbacks for `PrintOp`. See comments of `PrintOp` for the format. + virtual void PrintAtBlockEntry(mlir::Block &block, size_t num_indents, + llvm::raw_ostream &os); + virtual void PrintAfterOp(mlir::Operation *op, size_t num_indents, + mlir::AsmState &asm_state, llvm::raw_ostream &os); + bool IsEntryBlock(mlir::Block *block); // When we visit the op, visit all the CFG edges associated with that op. 
@@ -357,79 +390,52 @@ class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, block_to_cfg_edges_; protected: - JsirGeneralCfgEdge *GetCfgEdge( - mlir::ProgramPoint *pred, mlir::ProgramPoint *succ, - std::optional> liveness_info, - llvm::SmallVector pred_values, - llvm::SmallVector succ_values); - - void MaybeEmplaceCfgEdge(mlir::ProgramPoint *from, mlir::ProgramPoint *to, - mlir::Operation *op, - std::optional> - liveness_info = std::nullopt, - llvm::SmallVector pred_values = {}, - llvm::SmallVector succ_values = {}); - - static llvm::SmallVector GetPredValuesEmpty(mlir::Block *block) { - return {}; - } + struct CfgEdgeOptions { + llvm::SmallVector from; + llvm::SmallVector to; + mlir::Operation *owner; + std::optional liveness_info; + std::variant> + pred_values; + mlir::ValueRange succ_values; + }; - void MaybeEmplaceCfgEdgesFromRegion( - mlir::Region &from_exits_of, mlir::ProgramPoint *to, mlir::Operation *op, - std::optional> liveness_info = - std::nullopt, - absl::FunctionRef(mlir::Block *)> - get_pred_values = GetPredValuesEmpty, - llvm::SmallVector succ_values = {}); - - void MaybeEmplaceCfgEdgesBetweenRegions( - mlir::Region &from_exits_of, mlir::Region &to_entry_of, - mlir::Operation *op, - std::optional> liveness_info = - std::nullopt, - absl::FunctionRef(mlir::Block *)> - get_pred_values = GetPredValuesEmpty, - llvm::SmallVector succ_values = {}); - - static llvm::SmallVector GetExprRegionEndValues( - mlir::Block *block) { - auto term_op = block->getTerminator(); - if (auto expr_region_end_op = - llvm::dyn_cast(term_op)) { - return {expr_region_end_op.getArgument()}; - } - return llvm::SmallVector{}; - } + void MaybeEmplaceCfgEdges(CfgEdgeOptions options) { + for (auto &from : options.from) { + for (auto &to : options.to) { + for (auto &op : *from->getBlock()) { + if (getProgramPointBefore(&op) == from) { + break; + } + if (llvm::isa(op) || + llvm::isa(op)) { + return; + } + } - static llvm::SmallVector GetExprRegionEndValuesFromRegion( - 
mlir::Region ®ion) { - for (auto &block : region.getBlocks()) { - if (block.hasNoSuccessors()) { - auto end_values = GetExprRegionEndValues(&block); - if (!end_values.empty()) { - return end_values; + mlir::ValueRange pred_values; + if (std::holds_alternative(options.pred_values)) { + pred_values = std::get(options.pred_values); + } else { + pred_values = + std::get>( + options.pred_values)(from->getBlock()); } - } - } - return llvm::SmallVector{}; - } - static mlir::Region &GetForStatementContinueTargetRegion( - JshirForStatementOp for_stmt) { - if (!for_stmt.getUpdate().empty()) { - return for_stmt.getUpdate(); - } - if (!for_stmt.getTest().empty()) { - return for_stmt.getTest(); + JsirGeneralCfgEdge *edge = getLatticeAnchor( + from, to, pred_values, options.succ_values, options.liveness_info); + if (options.owner != nullptr) { + op_to_cfg_edges_[options.owner].push_back(edge); + auto from_state = GetStateImpl(from); + from_state.AddDependent(getProgramPointAfter(options.owner)); + } else { + block_to_cfg_edges_[edge->getSucc()->getBlock()].push_back(edge); + } + } } - return for_stmt.getBody(); } - // Gets the state at the program point. - template - JsirStateRef GetStateImpl(mlir::LatticeAnchor anchor); - - mlir::LogicalResult initialize(mlir::Operation *op) override; void InitializeBlock(mlir::Block *block); // Since our analysis algorithm is based on MLIR's dataflow analysis, we need @@ -442,8 +448,8 @@ class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, // provides all successors as dependencies. // This method is called inside `InitializeBlock`. virtual void InitializeBlockDependencies(mlir::Block *block); + virtual void VisitBlock(mlir::Block *block); - virtual void VisitOp(mlir::Operation *op); // This method mainly serves to "join" states from blocks. i.e., this method // should implement the "join" operation in a dataflow analysis. 
It should @@ -452,96 +458,90 @@ class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, // block to the end of the predecessor for a backward analysis. virtual void VisitCfgEdge(JsirGeneralCfgEdge *edge); - // Callbacks for `PrintOp`. See comments of `PrintOp` for the format. - virtual void PrintAtBlockEntry(mlir::Block &block, size_t num_indents, - llvm::raw_ostream &os); - virtual void PrintAfterOp(mlir::Operation *op, size_t num_indents, - mlir::AsmState &asm_state, llvm::raw_ostream &os); + // Gets the state at the program point. + template + JsirStateRef GetStateImpl(mlir::LatticeAnchor anchor); - mlir::DataFlowSolver &solver_; + // Helper function to get the `StateRef`s for the operands and results of an + // op. For forward analysis, the input should be the operands and the output + // should be the results. For backward analysis, the input should be the + // results and the output should be the operands. + struct ValueStateRefs { + std::vector inputs; + std::vector> outputs; + }; + ValueStateRefs GetValueStateRefs(mlir::Operation *op); - private: - // TODO(b/425421947) Could this be not a member variable? - JumpEnv jump_env_; + llvm::SmallVector Before(mlir::Operation *op) { + return {getProgramPointBefore(op)}; + } - // TODO(b/425421947) It would be nice to refactor the jump environment so this - // logic can go in the cases for each op in `initialize`. 
- std::optional GetJumpTargets(mlir::Operation *op) { - if (auto with_stmt = llvm::dyn_cast(op); - with_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(with_stmt), - .unlabeled_break_target = std::nullopt, - .continue_target = std::nullopt, - }; - } - if (auto if_stmt = llvm::dyn_cast(op); - if_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(if_stmt), - .unlabeled_break_target = std::nullopt, - .continue_target = std::nullopt, - }; - } - if (auto switch_stmt = llvm::dyn_cast(op); - switch_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(switch_stmt), - .unlabeled_break_target = getProgramPointAfter(switch_stmt), - .continue_target = std::nullopt, - }; - } - if (auto while_stmt = llvm::dyn_cast(op); - while_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(while_stmt), - .unlabeled_break_target = getProgramPointAfter(while_stmt), - .continue_target = - getProgramPointBefore(&while_stmt.getTest().front()), - }; + llvm::SmallVector After(mlir::Operation *op) { + return {getProgramPointAfter(op)}; + } + + llvm::SmallVector Before(mlir::Block *block) { + return {getProgramPointBefore(block)}; + } + + llvm::SmallVector After(mlir::Block *block) { + return {getProgramPointAfter(block)}; + } + + llvm::SmallVector Before(mlir::Region ®ion) { + CHECK(!region.empty()); + return {getProgramPointBefore(®ion.front())}; + } + + llvm::SmallVector After(mlir::Region ®ion) { + llvm::SmallVector after_points; + for (mlir::Block &block : region) { + if (block.getSuccessors().empty()) { + after_points.push_back(getProgramPointAfter(&block)); + } } - if (auto do_while_stmt = llvm::dyn_cast(op); - do_while_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(do_while_stmt), - .unlabeled_break_target = getProgramPointAfter(do_while_stmt), - .continue_target = - 
getProgramPointBefore(&do_while_stmt.getTest().front()), - }; + return after_points; + } + + static mlir::ValueRange GetExprRegionEndValues(mlir::Block *block) { + auto term_op = block->getTerminator(); + if (auto expr_region_end_op = + llvm::dyn_cast(term_op)) { + return expr_region_end_op->getOperands(); } - if (auto for_stmt = llvm::dyn_cast(op); - for_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(for_stmt), - .unlabeled_break_target = getProgramPointAfter(for_stmt), - .continue_target = getProgramPointBefore( - &GetForStatementContinueTargetRegion(for_stmt).front()), - }; + return {}; + } + + static mlir::ValueRange GetExprRegionEndValuesFromRegion( + mlir::Region ®ion) { + for (auto &block : region.getBlocks()) { + if (block.hasNoSuccessors()) { + auto end_values = GetExprRegionEndValues(&block); + if (!end_values.empty()) { + return end_values; + } + } } - if (auto for_in_stmt = llvm::dyn_cast(op); - for_in_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(for_in_stmt), - .unlabeled_break_target = getProgramPointAfter(for_in_stmt), - .continue_target = - getProgramPointBefore(&for_in_stmt.getBody().front()), - }; + return {}; + } + + static mlir::Region &GetForStatementContinueTargetRegion( + JshirForStatementOp for_stmt) { + if (!for_stmt.getUpdate().empty()) { + return for_stmt.getUpdate(); } - if (auto for_of_stmt = llvm::dyn_cast(op); - for_of_stmt != nullptr) { - return JumpTargets{ - .labeled_break_target = getProgramPointAfter(for_of_stmt), - .unlabeled_break_target = getProgramPointAfter(for_of_stmt), - .continue_target = - getProgramPointBefore(&for_of_stmt.getBody().front()), - }; + if (!for_stmt.getTest().empty()) { + return for_stmt.getTest(); } - return std::nullopt; + return for_stmt.getBody(); } + private: + // TODO(b/425421947) Could this be not a member variable? 
+ JumpEnv jump_env_; + std::optional WithJumpTargets( - mlir::Operation *op) { - auto maybe_jump_targets = GetJumpTargets(op); + std::optional maybe_jump_targets) { if (maybe_jump_targets.has_value()) { return jump_env_.WithJumpTargets(maybe_jump_targets.value()); } @@ -557,126 +557,26 @@ class JsirDenseDataFlowAnalysis : public mlir::DataFlowAnalysis, return std::nullopt; } - // Override `mlir::DataFlowAnalysis::visit` and redirect to `Visit{Op,Block}`. - mlir::LogicalResult visit(mlir::ProgramPoint *point) override; -}; - -template -using JsirDenseForwardDataFlowAnalysis = - JsirDenseDataFlowAnalysis; - -template -using JsirDenseBackwardDataFlowAnalysis = - JsirDenseDataFlowAnalysis; - -// ============================================================================= -// JsirDataFlowAnalysis -// ============================================================================= -// A dataflow analysis API that attaches lattices to both values and operations. -// This analysis supports both forward and backward analysis. -template -class JsirDataFlowAnalysis - : public JsirDenseDataFlowAnalysis, - public JsirSparseStates> { - public: - using Base = JsirDenseDataFlowAnalysis; - - explicit JsirDataFlowAnalysis(mlir::DataFlowSolver &solver) - : JsirDenseDataFlowAnalysis(solver) {} - - // The initial state on a boundary `mlir::Value`, e.g. a parameter of an entry - // block. This is used in both backward and forward analysis, when visiting - // the CFG edges. - virtual ValueT BoundaryInitialValue() const = 0; - - // Sets the initial state on a boundary `mlir::Block`, i.e. the entry state of - // an entry block for a forward analysis, or the exit state of an exit block - // for a backward analysis. 
- virtual void InitializeBoundaryBlock( - mlir::Block *block, JsirStateRef boundary_state, - llvm::MutableArrayRef> arg_states) = 0; - - using Base::InitializeBoundaryBlock; - void InitializeBoundaryBlock(mlir::Block *block, - JsirStateRef boundary_state) override { - std::vector> arg_states; - for (mlir::Value arg : block->getArguments()) { - arg_states.push_back(GetStateAt(arg)); - } - return InitializeBoundaryBlock(block, boundary_state, arg_states); - } - - // This virtual method is the transfer function for an operation. It is called - // by its overloaded protected method. Same as its version in dense analysis, - // what the input and output of sparse states (`ValueT`) should come from - // is different in forward and backward analysis. - // - // Generally, a transfer function in a dataflow analysis can be represented in - // a form of - // - // output = gen ∪ (input - kill) - // - // where ∪ is the lattice join operation. Usually, for forward analysis, the - // `gen` set comes from the `results` in a JSIR `Operation`, and `kill` set - // comes from the `operands`. For backward analysis, it is the opposite case. - // - // For sparse values, we would update the values in `gen` set, and read values - // from `kill` set. Thus, we have the following table for sparse values: - // +--------+-------------------+-------------------+ - // | | Forward Analysis | Backward Analysis | - // +--------+-------------------+-------------------+ - // | Input | Operands | Results | - // +--------+-------------------+-------------------+ - // | Output | Results | Operands | - // +--------+-------------------+-------------------+ - virtual void VisitOp( - mlir::Operation *op, llvm::ArrayRef sparse_input, - const StateT *dense_input, - llvm::MutableArrayRef> sparse_output, - JsirStateRef dense_output) = 0; - - // Gets the state at an SSA value. 
- JsirStateRef GetStateAt(mlir::Value value) final; - - protected: - using Base::GetCfgEdge; - using Base::InitializeBlockDependencies; - using Base::VisitBlock; - using Base::VisitOp; - void VisitCfgEdge(JsirGeneralCfgEdge *edge) override; - - // Helper function to get the `StateRef`s for the operands and results of an - // op. For forward analysis, the input should be the operands and the output - // should be the results. For backward analysis, the input should be the - // results and the output should be the operands. - struct ValueStateRefs { - std::vector inputs; - std::vector> outputs; - }; - ValueStateRefs GetValueStateRefs(mlir::Operation *op); - - void PrintAfterOp(mlir::Operation *op, size_t num_indents, - mlir::AsmState &asm_state, llvm::raw_ostream &os) override; - - private: mlir::LogicalResult initialize(mlir::Operation *op) override; - // Override the transfer function in `JsirDenseDataFlowAnalysis` and - // redirect to the transfer function supporting sparse values in - // `JsirDataFlowAnalysis`. void VisitOp(mlir::Operation *op, const StateT *input, - JsirStateRef output) override; + JsirStateRef output); + + virtual void VisitOp(mlir::Operation *op); - using Base::solver_; + mlir::DataFlowSolver &solver_; + + // Override `mlir::DataFlowAnalysis::visit` and redirect to `Visit{Op,Block}`. 
+ mlir::LogicalResult visit(mlir::ProgramPoint *point) override; }; -template +template using JsirForwardDataFlowAnalysis = - JsirDataFlowAnalysis; + JsirDataFlowAnalysis; -template +template using JsirBackwardDataFlowAnalysis = - JsirDataFlowAnalysis; + JsirDataFlowAnalysis; // ============================================================================= // JsirStateRef @@ -717,86 +617,21 @@ void JsirStateRef::Join(const T &lattice) { } // ============================================================================= -// JsirDenseDataFlowAnalysis +// JsirDataFlowAnalysis // ============================================================================= -template -JsirGeneralCfgEdge *JsirDenseDataFlowAnalysis::GetCfgEdge( - mlir::ProgramPoint *pred, mlir::ProgramPoint *succ, - std::optional> liveness_info, - llvm::SmallVector pred_values, - llvm::SmallVector succ_values) { - return getLatticeAnchor(pred, succ, pred_values, - succ_values, liveness_info); -} - -template -void JsirDenseDataFlowAnalysis::MaybeEmplaceCfgEdge( - mlir::ProgramPoint *from, mlir::ProgramPoint *to, mlir::Operation *op, - std::optional> liveness_info, - llvm::SmallVector pred_values, - llvm::SmallVector succ_values) { - for (auto &op : *from->getBlock()) { - if (getProgramPointBefore(&op) == from) { - break; - } - if (llvm::isa(op) || - llvm::isa(op)) { - return; - } - } - - op_to_cfg_edges_[op].push_back( - GetCfgEdge(from, to, liveness_info, pred_values, succ_values)); - auto from_state = GetStateImpl(from); - from_state.AddDependent(getProgramPointAfter(op)); -} - -template -void JsirDenseDataFlowAnalysis:: - MaybeEmplaceCfgEdgesFromRegion( - mlir::Region &from_exits_of, mlir::ProgramPoint *to, - mlir::Operation *op, - std::optional> liveness_info, - absl::FunctionRef(mlir::Block *)> - get_pred_values, - llvm::SmallVector succ_values) { - for (mlir::Block &block : from_exits_of) { - if (block.getSuccessors().empty()) { - mlir::ProgramPoint *after_block = getProgramPointAfter(&block); - 
MaybeEmplaceCfgEdge(after_block, to, op, liveness_info, - get_pred_values(&block), succ_values); - } - } -} - -template -void JsirDenseDataFlowAnalysis:: - MaybeEmplaceCfgEdgesBetweenRegions( - mlir::Region &from_exits_of, mlir::Region &to_entry_of, - mlir::Operation *op, - std::optional> liveness_info, - absl::FunctionRef(mlir::Block *)> - get_pred_values, - llvm::SmallVector succ_values) { - CHECK(!from_exits_of.empty()); - mlir::ProgramPoint *entry = getProgramPointBefore(&to_entry_of.front()); - MaybeEmplaceCfgEdgesFromRegion(from_exits_of, entry, op, liveness_info, - get_pred_values, succ_values); -} - -template +template template -JsirStateRef JsirDenseDataFlowAnalysis::GetStateImpl( +JsirStateRef JsirDataFlowAnalysis::GetStateImpl( mlir::LatticeAnchor anchor) { auto *element = mlir::DataFlowAnalysis::getOrCreate>(anchor); return JsirStateRef{element, &solver_, this}; } -template +template JsirStateRef -JsirDenseDataFlowAnalysis::GetStateBefore( +JsirDataFlowAnalysis::GetStateBefore( mlir::Operation *op) { if (auto *prev_op = op->getPrevNode()) { return GetStateAfter(prev_op); @@ -805,23 +640,23 @@ JsirDenseDataFlowAnalysis::GetStateBefore( } } -template +template JsirStateRef -JsirDenseDataFlowAnalysis::GetStateAfter( +JsirDataFlowAnalysis::GetStateAfter( mlir::Operation *op) { return GetStateImpl(getProgramPointAfter(op)); } -template +template JsirStateRef -JsirDenseDataFlowAnalysis::GetStateAtEntryOf( +JsirDataFlowAnalysis::GetStateAtEntryOf( mlir::Block *block) { return GetStateImpl(getProgramPointBefore(block)); } -template +template JsirStateRef -JsirDenseDataFlowAnalysis::GetStateAtEndOf( +JsirDataFlowAnalysis::GetStateAtEndOf( mlir::Block *block) { if (block->empty()) { return GetStateAtEntryOf(block); @@ -830,8 +665,8 @@ JsirDenseDataFlowAnalysis::GetStateAtEndOf( } } -template -void JsirDenseDataFlowAnalysis::PrintOp( +template +void JsirDataFlowAnalysis::PrintOp( mlir::Operation *op, size_t num_indents, mlir::AsmState &asm_state, llvm::raw_ostream 
&os) { size_t num_results = op->getNumResults(); @@ -888,8 +723,8 @@ void JsirDenseDataFlowAnalysis::PrintOp( PrintAfterOp(op, num_indents, asm_state, os); } -template -void JsirDenseDataFlowAnalysis::PrintRegion( +template +void JsirDataFlowAnalysis::PrintRegion( mlir::Region ®ion, size_t num_indents, mlir::AsmState &asm_state, llvm::raw_ostream &os) { os << "{\n"; @@ -918,8 +753,8 @@ void JsirDenseDataFlowAnalysis::PrintRegion( os << "}"; } -template -bool JsirDenseDataFlowAnalysis::IsEntryBlock( +template +bool JsirDataFlowAnalysis::IsEntryBlock( mlir::Block *block) { mlir::Operation *parent_op = block->getParentOp(); @@ -947,9 +782,15 @@ bool JsirDenseDataFlowAnalysis::IsEntryBlock( // JsirExecutable. // - On every CFG edge (Block -> Block): // JsirExecutable. -template -mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( +template +mlir::LogicalResult JsirDataFlowAnalysis::initialize( mlir::Operation *op) { + // The op depends on its input operands. + for (mlir::Value operand : op->getOperands()) { + JsirStateRef operand_state_ref = GetStateAt(operand); + operand_state_ref.AddDependent(getProgramPointAfter(op)); + } + // Register `op`'s dependent state. 
if (op->getParentOp() != nullptr) { if constexpr (direction == DataflowDirection::kForward) { @@ -961,48 +802,40 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( } } - if (auto branch = llvm::dyn_cast(op); branch != nullptr) { - mlir::ProgramPoint *after_branch = getProgramPointAfter(branch); - mlir::ProgramPoint *before_dest = getProgramPointBefore(branch.getDest()); - - llvm::SmallVector succ_values = { - branch.getDest()->getArguments().begin(), - branch.getDest()->getArguments().end()}; + std::optional maybe_jump_targets; - auto *edge = GetCfgEdge(after_branch, before_dest, std::nullopt, - branch.getDestOperands(), succ_values); - block_to_cfg_edges_[edge->getSucc()->getBlock()].push_back(edge); + if (auto branch = llvm::dyn_cast(op); branch != nullptr) { + MaybeEmplaceCfgEdges({ + .from = After(branch), + .to = Before(branch.getDest()), + .pred_values = branch.getDestOperands(), + .succ_values = branch.getDest()->getArguments(), + }); } if (auto cond_branch = llvm::dyn_cast(op); cond_branch != nullptr) { - mlir::ProgramPoint *after_cond_branch = getProgramPointAfter(cond_branch); - mlir::ProgramPoint *before_true_dest = - getProgramPointBefore(cond_branch.getTrueDest()); - mlir::ProgramPoint *before_false_dest = - getProgramPointBefore(cond_branch.getFalseDest()); - - llvm::SmallVector true_succ_values = { - cond_branch.getTrueDest()->getArguments().begin(), - cond_branch.getTrueDest()->getArguments().end()}; - llvm::SmallVector false_succ_values = { - cond_branch.getFalseDest()->getArguments().begin(), - cond_branch.getFalseDest()->getArguments().end()}; - - auto *true_edge = - GetCfgEdge(after_cond_branch, before_true_dest, - std::tuple{cond_branch.getCondition(), - LivenessKind::kLiveIfTrueOrUnknown}, - cond_branch.getTrueDestOperands(), true_succ_values); - block_to_cfg_edges_[true_edge->getSucc()->getBlock()].push_back(true_edge); - - auto *false_edge = - GetCfgEdge(after_cond_branch, before_false_dest, - 
std::tuple{cond_branch.getCondition(), - LivenessKind::kLiveIfFalseOrUnknown}, - cond_branch.getFalseDestOperands(), false_succ_values); - block_to_cfg_edges_[false_edge->getSucc()->getBlock()].push_back( - false_edge); + MaybeEmplaceCfgEdges({ + .from = After(cond_branch), + .to = Before(cond_branch.getTrueDest()), + .liveness_info = + std::tuple{cond_branch.getCondition(), + mlir::BoolAttr::get(cond_branch.getContext(), true), + LivenessKind::kLiveIfEqualOrUnknown}, + .pred_values = cond_branch.getTrueDestOperands(), + .succ_values = cond_branch.getTrueDest()->getArguments(), + }); + + MaybeEmplaceCfgEdges({ + .from = After(cond_branch), + .to = Before(cond_branch.getFalseDest()), + .liveness_info = + std::tuple{cond_branch.getCondition(), + mlir::BoolAttr::get(cond_branch.getContext(), false), + LivenessKind::kLiveIfEqualOrUnknown}, + .pred_values = cond_branch.getFalseDestOperands(), + .succ_values = cond_branch.getFalseDest()->getArguments(), + }); } // Handle ops with a single region. @@ -1018,17 +851,30 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( llvm::isa(op) /* TODO Should this be here? 
*/ || llvm::isa(op) || llvm::isa(op)) { - mlir::ProgramPoint *before_op = getProgramPointBefore(op); - mlir::ProgramPoint *after_op = getProgramPointAfter(op); - + if (llvm::isa(op)) { + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(op), + .unlabeled_break_target = std::nullopt, + .continue_target = std::nullopt, + }; + } if (!op->getRegion(0).empty()) { - mlir::ProgramPoint *before_region = - getProgramPointBefore(&op->getRegion(0).front()); - - MaybeEmplaceCfgEdge(before_op, before_region, op); - MaybeEmplaceCfgEdgesFromRegion(op->getRegion(0), after_op, op); + MaybeEmplaceCfgEdges({ + .from = Before(op), + .to = Before(op->getRegion(0)), + .owner = op, + }); + MaybeEmplaceCfgEdges({ + .from = After(op->getRegion(0)), + .to = After(op), + .owner = op, + }); } else { - MaybeEmplaceCfgEdge(before_op, after_op, op); + MaybeEmplaceCfgEdges({ + .from = Before(op), + .to = After(op), + .owner = op, + }); } } @@ -1044,30 +890,51 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └──► if (auto if_stmt = llvm::dyn_cast(op); if_stmt != nullptr) { - mlir::ProgramPoint *before_if_stmt = getProgramPointBefore(if_stmt); - mlir::ProgramPoint *after_if_stmt = getProgramPointAfter(if_stmt); - mlir::ProgramPoint *before_consequent = - getProgramPointBefore(&if_stmt.getConsequent().front()); - - MaybeEmplaceCfgEdge( - before_if_stmt, before_consequent, if_stmt, - std::tuple{if_stmt.getTest(), LivenessKind::kLiveIfTrueOrUnknown}); - MaybeEmplaceCfgEdgesFromRegion(if_stmt.getConsequent(), after_if_stmt, - if_stmt); + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(if_stmt), + .unlabeled_break_target = std::nullopt, + .continue_target = std::nullopt, + }; + MaybeEmplaceCfgEdges({ + .from = Before(if_stmt), + .to = Before(if_stmt.getConsequent()), + .owner = if_stmt, + .liveness_info = + std::tuple{if_stmt.getTest(), + mlir::BoolAttr::get(if_stmt.getContext(), true), + LivenessKind::kLiveIfEqualOrUnknown}, + }); + 
MaybeEmplaceCfgEdges({ + .from = After(if_stmt.getConsequent()), + .to = After(if_stmt), + .owner = if_stmt, + }); if (!if_stmt.getAlternate().empty()) { - mlir::ProgramPoint *before_alternate = - getProgramPointBefore(&if_stmt.getAlternate().front()); - - MaybeEmplaceCfgEdge( - before_if_stmt, before_alternate, if_stmt, - std::tuple{if_stmt.getTest(), LivenessKind::kLiveIfFalseOrUnknown}); - MaybeEmplaceCfgEdgesFromRegion(if_stmt.getAlternate(), after_if_stmt, - if_stmt); + MaybeEmplaceCfgEdges({ + .from = Before(if_stmt), + .to = Before(if_stmt.getAlternate()), + .owner = if_stmt, + .liveness_info = + std::tuple{if_stmt.getTest(), + mlir::BoolAttr::get(if_stmt.getContext(), false), + LivenessKind::kLiveIfEqualOrUnknown}, + }); + MaybeEmplaceCfgEdges({ + .from = After(if_stmt.getAlternate()), + .to = After(if_stmt), + .owner = if_stmt, + }); } else { - MaybeEmplaceCfgEdge( - before_if_stmt, after_if_stmt, if_stmt, - std::tuple{if_stmt.getTest(), LivenessKind::kLiveIfFalseOrUnknown}); + MaybeEmplaceCfgEdges({ + .from = Before(if_stmt), + .to = After(if_stmt), + .owner = if_stmt, + .liveness_info = + std::tuple{if_stmt.getTest(), + mlir::BoolAttr::get(if_stmt.getContext(), false), + LivenessKind::kLiveIfEqualOrUnknown}, + }); } } @@ -1083,16 +950,21 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └─────► if (auto block_stmt = llvm::dyn_cast(op); block_stmt != nullptr) { - mlir::ProgramPoint *before_block_stmt = getProgramPointBefore(block_stmt); - mlir::ProgramPoint *after_block_stmt = getProgramPointAfter(block_stmt); - mlir::ProgramPoint *before_directives = - getProgramPointBefore(&block_stmt.getDirectives().front()); - - MaybeEmplaceCfgEdge(before_block_stmt, before_directives, block_stmt); - MaybeEmplaceCfgEdgesBetweenRegions(block_stmt.getDirectives(), - block_stmt.getBody(), block_stmt); - MaybeEmplaceCfgEdgesFromRegion(block_stmt.getBody(), after_block_stmt, - block_stmt); + MaybeEmplaceCfgEdges({ + .from = Before(block_stmt), + .to = 
Before(block_stmt.getDirectives()), + .owner = block_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(block_stmt.getDirectives()), + .to = Before(block_stmt.getBody()), + .owner = block_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(block_stmt.getBody()), + .to = After(block_stmt), + .owner = block_stmt, + }); } // ┌─────◄ @@ -1107,22 +979,41 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └──► if (auto while_stmt = llvm::dyn_cast(op); while_stmt != nullptr) { - mlir::ProgramPoint *before_while_stmt = getProgramPointBefore(while_stmt); - mlir::ProgramPoint *after_while_stmt = getProgramPointAfter(while_stmt); - mlir::ProgramPoint *before_test = - getProgramPointBefore(&while_stmt.getTest().front()); - - MaybeEmplaceCfgEdge(before_while_stmt, before_test, while_stmt); - MaybeEmplaceCfgEdgesBetweenRegions( - while_stmt.getTest(), while_stmt.getBody(), while_stmt, - std::tuple{GetExprRegionEndValuesFromRegion(while_stmt.getTest())[0], - LivenessKind::kLiveIfTrueOrUnknown}); - MaybeEmplaceCfgEdgesBetweenRegions(while_stmt.getBody(), - while_stmt.getTest(), while_stmt); - MaybeEmplaceCfgEdgesFromRegion( - while_stmt.getTest(), after_while_stmt, while_stmt, - std::tuple{GetExprRegionEndValuesFromRegion(while_stmt.getTest())[0], - LivenessKind::kLiveIfFalseOrUnknown}); + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(while_stmt), + .unlabeled_break_target = getProgramPointAfter(while_stmt), + .continue_target = getProgramPointBefore(&while_stmt.getTest().front()), + }; + MaybeEmplaceCfgEdges({ + .from = Before(while_stmt), + .to = Before(while_stmt.getTest()), + .owner = while_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(while_stmt.getTest()), + .to = Before(while_stmt.getBody()), + .owner = while_stmt, + .liveness_info = + std::tuple{ + GetExprRegionEndValuesFromRegion(while_stmt.getTest())[0], + mlir::BoolAttr::get(while_stmt.getContext(), true), + LivenessKind::kLiveIfEqualOrUnknown}, + }); + 
MaybeEmplaceCfgEdges({ + .from = After(while_stmt.getBody()), + .to = Before(while_stmt.getTest()), + .owner = while_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(while_stmt.getTest()), + .to = After(while_stmt), + .owner = while_stmt, + .liveness_info = + std::tuple{ + GetExprRegionEndValuesFromRegion(while_stmt.getTest())[0], + mlir::BoolAttr::get(while_stmt.getContext(), false), + LivenessKind::kLiveIfEqualOrUnknown}, + }); } // ┌─────◄ @@ -1137,24 +1028,42 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └─────► if (auto do_while_stmt = llvm::dyn_cast(op); do_while_stmt != nullptr) { - mlir::ProgramPoint *before_do_while_stmt = - getProgramPointBefore(do_while_stmt); - mlir::ProgramPoint *after_do_while_stmt = - getProgramPointAfter(do_while_stmt); - mlir::ProgramPoint *before_body = - getProgramPointBefore(&do_while_stmt.getBody().front()); - - MaybeEmplaceCfgEdge(before_do_while_stmt, before_body, do_while_stmt); - MaybeEmplaceCfgEdgesBetweenRegions(do_while_stmt.getBody(), - do_while_stmt.getTest(), do_while_stmt); - MaybeEmplaceCfgEdgesBetweenRegions( - do_while_stmt.getTest(), do_while_stmt.getBody(), do_while_stmt, - std::tuple{GetExprRegionEndValuesFromRegion(do_while_stmt.getTest())[0], - LivenessKind::kLiveIfTrueOrUnknown}); - MaybeEmplaceCfgEdgesFromRegion( - do_while_stmt.getTest(), after_do_while_stmt, do_while_stmt, - std::tuple{GetExprRegionEndValuesFromRegion(do_while_stmt.getTest())[0], - LivenessKind::kLiveIfFalseOrUnknown}); + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(do_while_stmt), + .unlabeled_break_target = getProgramPointAfter(do_while_stmt), + .continue_target = + getProgramPointBefore(&do_while_stmt.getTest().front()), + }; + MaybeEmplaceCfgEdges({ + .from = Before(do_while_stmt), + .to = Before(do_while_stmt.getBody()), + .owner = do_while_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(do_while_stmt.getBody()), + .to = Before(do_while_stmt.getTest()), + .owner = do_while_stmt, + 
}); + MaybeEmplaceCfgEdges({ + .from = After(do_while_stmt.getTest()), + .to = Before(do_while_stmt.getBody()), + .owner = do_while_stmt, + .liveness_info = + std::tuple{ + GetExprRegionEndValuesFromRegion(do_while_stmt.getTest())[0], + mlir::BoolAttr::get(do_while_stmt.getContext(), true), + LivenessKind::kLiveIfEqualOrUnknown}, + }); + MaybeEmplaceCfgEdges({ + .from = After(do_while_stmt.getTest()), + .to = After(do_while_stmt), + .owner = do_while_stmt, + .liveness_info = + std::tuple{ + GetExprRegionEndValuesFromRegion(do_while_stmt.getTest())[0], + mlir::BoolAttr::get(do_while_stmt.getContext(), false), + LivenessKind::kLiveIfEqualOrUnknown}, + }); } // ┌─────◄ @@ -1175,54 +1084,75 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └─────► if (auto for_stmt = llvm::dyn_cast(op); for_stmt != nullptr) { - mlir::ProgramPoint *before_for_stmt = getProgramPointBefore(for_stmt); - mlir::ProgramPoint *after_for_stmt = getProgramPointAfter(for_stmt); - + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(for_stmt), + .unlabeled_break_target = getProgramPointAfter(for_stmt), + .continue_target = getProgramPointBefore( + &GetForStatementContinueTargetRegion(for_stmt).front()), + }; // Emplace an edge into the first non-empty region of the for-statement. mlir::Region &first_region = !for_stmt.getInit().empty() ? for_stmt.getInit() : (!for_stmt.getTest().empty() ? for_stmt.getTest() : for_stmt.getBody()); - mlir::ProgramPoint *before_first_region = - getProgramPointBefore(&first_region.front()); - MaybeEmplaceCfgEdge(before_for_stmt, before_first_region, for_stmt); + MaybeEmplaceCfgEdges({ + .from = Before(for_stmt), + .to = Before(first_region), + .owner = for_stmt, + }); if (!for_stmt.getInit().empty()) { - // Init; mlir::Region &successor = !for_stmt.getTest().empty() ? 
for_stmt.getTest() : for_stmt.getBody(); - MaybeEmplaceCfgEdgesBetweenRegions(for_stmt.getInit(), successor, - for_stmt); + MaybeEmplaceCfgEdges({ + .from = After(for_stmt.getInit()), + .to = Before(successor), + .owner = for_stmt, + }); } if (!for_stmt.getTest().empty()) { - // Test - MaybeEmplaceCfgEdgesBetweenRegions( - for_stmt.getTest(), for_stmt.getBody(), for_stmt, - std::tuple{GetExprRegionEndValuesFromRegion(for_stmt.getTest())[0], - LivenessKind::kLiveIfTrueOrUnknown}); - MaybeEmplaceCfgEdgesFromRegion( - for_stmt.getTest(), after_for_stmt, for_stmt, - std::tuple{GetExprRegionEndValuesFromRegion(for_stmt.getTest())[0], - LivenessKind::kLiveIfFalseOrUnknown}); + MaybeEmplaceCfgEdges({ + .from = After(for_stmt.getTest()), + .to = Before(for_stmt.getBody()), + .owner = for_stmt, + .liveness_info = + std::tuple{ + GetExprRegionEndValuesFromRegion(for_stmt.getTest())[0], + mlir::BoolAttr::get(for_stmt.getContext(), true), + LivenessKind::kLiveIfEqualOrUnknown}, + }); + MaybeEmplaceCfgEdges({ + .from = After(for_stmt.getTest()), + .to = After(for_stmt), + .owner = for_stmt, + .liveness_info = + std::tuple{ + GetExprRegionEndValuesFromRegion(for_stmt.getTest())[0], + mlir::BoolAttr::get(for_stmt.getContext(), false), + LivenessKind::kLiveIfEqualOrUnknown}, + }); } { - // Body - MaybeEmplaceCfgEdgesBetweenRegions( - for_stmt.getBody(), GetForStatementContinueTargetRegion(for_stmt), - for_stmt); + MaybeEmplaceCfgEdges({ + .from = After(for_stmt.getBody()), + .to = Before(GetForStatementContinueTargetRegion(for_stmt)), + .owner = for_stmt, + }); } if (!for_stmt.getUpdate().empty()) { - // Update mlir::Region &successor = !for_stmt.getTest().empty() ? 
for_stmt.getTest() : for_stmt.getBody(); - MaybeEmplaceCfgEdgesBetweenRegions(for_stmt.getUpdate(), successor, - for_stmt); + MaybeEmplaceCfgEdges({ + .from = After(for_stmt.getUpdate()), + .to = Before(successor), + .owner = for_stmt, + }); } } @@ -1235,17 +1165,27 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └─────► if (auto for_in_stmt = llvm::dyn_cast(op); for_in_stmt != nullptr) { - mlir::ProgramPoint *before_for_in_stmt = getProgramPointBefore(for_in_stmt); - mlir::ProgramPoint *after_for_in_stmt = getProgramPointAfter(for_in_stmt); - - mlir::ProgramPoint *before_body = - getProgramPointBefore(&for_in_stmt.getBody().front()); - - MaybeEmplaceCfgEdge(before_for_in_stmt, before_body, for_in_stmt); - MaybeEmplaceCfgEdgesBetweenRegions(for_in_stmt.getBody(), - for_in_stmt.getBody(), for_in_stmt); - MaybeEmplaceCfgEdgesFromRegion(for_in_stmt.getBody(), after_for_in_stmt, - for_in_stmt); + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(for_in_stmt), + .unlabeled_break_target = getProgramPointAfter(for_in_stmt), + .continue_target = + getProgramPointBefore(&for_in_stmt.getBody().front()), + }; + MaybeEmplaceCfgEdges({ + .from = Before(for_in_stmt), + .to = Before(for_in_stmt.getBody()), + .owner = for_in_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(for_in_stmt.getBody()), + .to = Before(for_in_stmt.getBody()), + .owner = for_in_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(for_in_stmt.getBody()), + .to = After(for_in_stmt), + .owner = for_in_stmt, + }); } // ┌─────◄ @@ -1257,17 +1197,27 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └─────► if (auto for_of_stmt = llvm::dyn_cast(op); for_of_stmt != nullptr) { - mlir::ProgramPoint *before_for_of_stmt = getProgramPointBefore(for_of_stmt); - mlir::ProgramPoint *after_for_of_stmt = getProgramPointAfter(for_of_stmt); - - mlir::ProgramPoint *before_body = - getProgramPointBefore(&for_of_stmt.getBody().front()); - - 
MaybeEmplaceCfgEdge(before_for_of_stmt, before_body, for_of_stmt); - MaybeEmplaceCfgEdgesBetweenRegions(for_of_stmt.getBody(), - for_of_stmt.getBody(), for_of_stmt); - MaybeEmplaceCfgEdgesFromRegion(for_of_stmt.getBody(), after_for_of_stmt, - for_of_stmt); + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(for_of_stmt), + .unlabeled_break_target = getProgramPointAfter(for_of_stmt), + .continue_target = + getProgramPointBefore(&for_of_stmt.getBody().front()), + }; + MaybeEmplaceCfgEdges({ + .from = Before(for_of_stmt), + .to = Before(for_of_stmt.getBody()), + .owner = for_of_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(for_of_stmt.getBody()), + .to = Before(for_of_stmt.getBody()), + .owner = for_of_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(for_of_stmt.getBody()), + .to = After(for_of_stmt), + .owner = for_of_stmt, + }); } // ┌─────◄ @@ -1279,43 +1229,47 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └──┴──► if (auto logical_expr = llvm::dyn_cast(op); logical_expr != nullptr) { - LivenessKind after_logical_expr_liveness_kind; - LivenessKind before_right_liveness_kind; + mlir::Attribute comparison_attr; switch (*StringToJsLogicalOperator(logical_expr.getOperator_())) { case JsLogicalOperator::kAnd: // left && right => left ? right : left - after_logical_expr_liveness_kind = LivenessKind::kLiveIfFalseOrUnknown; - before_right_liveness_kind = LivenessKind::kLiveIfTrueOrUnknown; + comparison_attr = mlir::BoolAttr::get(logical_expr.getContext(), true); break; case JsLogicalOperator::kOr: // left || right => left ? left : right - after_logical_expr_liveness_kind = LivenessKind::kLiveIfTrueOrUnknown; - before_right_liveness_kind = LivenessKind::kLiveIfFalseOrUnknown; + comparison_attr = mlir::BoolAttr::get(logical_expr.getContext(), false); break; case JsLogicalOperator::kNullishCoalesce: // left ?? right => (left == null) ? 
right : left - after_logical_expr_liveness_kind = - LivenessKind::kLiveIfNonNullOrUnknown; - before_right_liveness_kind = LivenessKind::kLiveIfNullOrUnknown; + comparison_attr = JsirNullLiteralAttr::get(logical_expr.getContext()); break; } - mlir::ProgramPoint *before_logical_expr = - getProgramPointBefore(logical_expr); - mlir::ProgramPoint *after_logical_expr = getProgramPointAfter(logical_expr); - mlir::ProgramPoint *before_right = - getProgramPointBefore(&logical_expr.getRight().front()); mlir::Value left_value = logical_expr.getLeft(); - MaybeEmplaceCfgEdge( - before_logical_expr, after_logical_expr, logical_expr, - std::tuple{left_value, after_logical_expr_liveness_kind}, {left_value}, - logical_expr->getResults()); - MaybeEmplaceCfgEdge(before_logical_expr, before_right, logical_expr, - std::tuple{left_value, before_right_liveness_kind}); - MaybeEmplaceCfgEdgesFromRegion( - logical_expr.getRight(), after_logical_expr, logical_expr, std::nullopt, - GetExprRegionEndValues, logical_expr->getResults()); + MaybeEmplaceCfgEdges({ + .from = Before(logical_expr), + .to = After(logical_expr), + .owner = logical_expr, + .liveness_info = std::tuple{left_value, comparison_attr, + LivenessKind::kLiveIfNotEqualOrUnknown}, + .pred_values = mlir::ValueRange{left_value}, + .succ_values = logical_expr->getResults(), + }); + MaybeEmplaceCfgEdges({ + .from = Before(logical_expr), + .to = Before(logical_expr.getRight()), + .owner = logical_expr, + .liveness_info = std::tuple{left_value, comparison_attr, + LivenessKind::kLiveIfEqualOrUnknown}, + }); + MaybeEmplaceCfgEdges({ + .from = After(logical_expr.getRight()), + .to = After(logical_expr), + .owner = logical_expr, + .pred_values = GetExprRegionEndValues, + .succ_values = logical_expr->getResults(), + }); } // ┌─────◄ @@ -1330,36 +1284,42 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └──► if (auto conditional_expr = llvm::dyn_cast(op); conditional_expr != nullptr) { - mlir::ProgramPoint 
*before_conditional_expr = - getProgramPointBefore(conditional_expr); - mlir::ProgramPoint *after_conditional_expr = - getProgramPointAfter(conditional_expr); - mlir::ProgramPoint *before_consequent = - getProgramPointBefore(&conditional_expr.getConsequent().front()); - mlir::ProgramPoint *before_alternate = - getProgramPointBefore(&conditional_expr.getAlternate().front()); - - MaybeEmplaceCfgEdge(before_conditional_expr, before_consequent, - conditional_expr, - std::tuple{conditional_expr.getTest(), - LivenessKind::kLiveIfTrueOrUnknown}); - MaybeEmplaceCfgEdge(before_conditional_expr, before_alternate, - conditional_expr, - std::tuple{conditional_expr.getTest(), - LivenessKind::kLiveIfFalseOrUnknown}); - MaybeEmplaceCfgEdgesFromRegion(conditional_expr.getConsequent(), - after_conditional_expr, conditional_expr, - std::nullopt, GetExprRegionEndValues, - conditional_expr->getResults()); - MaybeEmplaceCfgEdgesFromRegion(conditional_expr.getAlternate(), - after_conditional_expr, conditional_expr, - std::nullopt, GetExprRegionEndValues, - conditional_expr->getResults()); + MaybeEmplaceCfgEdges({ + .from = Before(conditional_expr), + .to = Before(conditional_expr.getConsequent()), + .owner = conditional_expr, + .liveness_info = + std::tuple{conditional_expr.getTest(), + mlir::BoolAttr::get(conditional_expr.getContext(), true), + LivenessKind::kLiveIfEqualOrUnknown}, + }); + MaybeEmplaceCfgEdges({ + .from = Before(conditional_expr), + .to = Before(conditional_expr.getAlternate()), + .owner = conditional_expr, + .liveness_info = std::tuple{conditional_expr.getTest(), + mlir::BoolAttr::get( + conditional_expr.getContext(), false), + LivenessKind::kLiveIfEqualOrUnknown}, + }); + MaybeEmplaceCfgEdges({ + .from = After(conditional_expr.getConsequent()), + .to = After(conditional_expr), + .owner = conditional_expr, + .pred_values = GetExprRegionEndValues, + .succ_values = conditional_expr->getResults(), + }); + MaybeEmplaceCfgEdges({ + .from = 
After(conditional_expr.getAlternate()), + .to = After(conditional_expr), + .owner = conditional_expr, + .pred_values = GetExprRegionEndValues, + .succ_values = conditional_expr->getResults(), + }); } if (auto break_stmt = llvm::dyn_cast(op); break_stmt != nullptr) { - mlir::ProgramPoint *before_break_stmt = getProgramPointBefore(break_stmt); absl::StatusOr break_target; JsirIdentifierAttr label = break_stmt.getLabelAttr(); @@ -1370,15 +1330,16 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( } if (break_target.ok()) { - MaybeEmplaceCfgEdge(before_break_stmt, break_target.value(), break_stmt, - std::nullopt); + MaybeEmplaceCfgEdges({ + .from = Before(break_stmt), + .to = {break_target.value()}, + .owner = break_stmt, + }); } } if (auto continue_stmt = llvm::dyn_cast(op); continue_stmt != nullptr) { - mlir::ProgramPoint *before_continue_stmt = - getProgramPointBefore(continue_stmt); absl::StatusOr continue_target; JsirIdentifierAttr label = continue_stmt.getLabelAttr(); @@ -1389,8 +1350,11 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( } if (continue_target.ok()) { - MaybeEmplaceCfgEdge(before_continue_stmt, continue_target.value(), - continue_stmt, std::nullopt); + MaybeEmplaceCfgEdges({ + .from = Before(continue_stmt), + .to = {continue_target.value()}, + .owner = continue_stmt, + }); } } @@ -1409,26 +1373,128 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( // └─────► if (auto try_stmt = llvm::dyn_cast(op); try_stmt != nullptr) { - mlir::ProgramPoint *before_try_stmt = getProgramPointBefore(try_stmt); - mlir::ProgramPoint *after_try_stmt = getProgramPointAfter(try_stmt); - mlir::ProgramPoint *before_block = - getProgramPointBefore(&try_stmt.getBlock().front()); - - MaybeEmplaceCfgEdge(before_try_stmt, before_block, try_stmt, std::nullopt); + MaybeEmplaceCfgEdges({ + .from = Before(try_stmt), + .to = Before(try_stmt.getBlock()), + .owner = try_stmt, + }); if (!try_stmt.getFinalizer().empty()) { - 
MaybeEmplaceCfgEdgesBetweenRegions(try_stmt.getBlock(), - try_stmt.getFinalizer(), try_stmt); - MaybeEmplaceCfgEdgesFromRegion(try_stmt.getFinalizer(), after_try_stmt, - try_stmt, std::nullopt); + MaybeEmplaceCfgEdges({ + .from = After(try_stmt.getBlock()), + .to = Before(try_stmt.getFinalizer()), + .owner = try_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(try_stmt.getFinalizer()), + .to = After(try_stmt), + .owner = try_stmt, + }); } else { - MaybeEmplaceCfgEdgesFromRegion(try_stmt.getBlock(), after_try_stmt, - try_stmt, std::nullopt); + MaybeEmplaceCfgEdges({ + .from = After(try_stmt.getBlock()), + .to = After(try_stmt), + .owner = try_stmt, + }); + } + } + + // ┌─────◄ + // │ jshir.switch_statement ( + // └─────► ┌───────────────┐ + // │ cases region │ + // ┌─────◄ └───────────────┘ + // │ ); + // └─────► + if (auto switch_stmt = llvm::dyn_cast(op); + switch_stmt != nullptr) { + maybe_jump_targets = { + .labeled_break_target = getProgramPointAfter(switch_stmt), + .unlabeled_break_target = getProgramPointAfter(switch_stmt), + .continue_target = std::nullopt, + }; + MaybeEmplaceCfgEdges({ + .from = Before(switch_stmt), + .to = Before(switch_stmt.getCases()), + .owner = switch_stmt, + }); + MaybeEmplaceCfgEdges({ + .from = After(switch_stmt.getCases()), + .to = After(switch_stmt), + .owner = switch_stmt, + }); + } + + // ┌─────◄ + // │ jshir.switch_case ( + // ├─────► ┌───────────────┐ + // │ │ test region │ + // │ ┌──◄ └───────────────┘ + // └──┴──► ┌───────────────┐ + // │ consequent │ + // ┌─────◄ └───────────────┘ + // │ ); + // └─────► + if (auto switch_case = llvm::dyn_cast(op); + switch_case != nullptr) { + if (switch_case.getTest().empty()) { + MaybeEmplaceCfgEdges({ + .from = Before(switch_case), + .to = Before(switch_case.getConsequent()), + .owner = switch_case, + }); + } else { + MaybeEmplaceCfgEdges({ + .from = Before(switch_case), + .to = Before(switch_case.getTest()), + .owner = switch_case, + }); + MaybeEmplaceCfgEdges({ + .from = 
After(switch_case.getTest()), + .to = Before(switch_case.getConsequent()), + .owner = switch_case, + .liveness_info = + std::tuple{ + switch_case->getParentOfType() + .getDiscriminant(), + GetExprRegionEndValuesFromRegion(switch_case.getTest())[0], + LivenessKind::kLiveIfEqualOrUnknown}, + }); + + MaybeEmplaceCfgEdges({ + .from = After(switch_case.getTest()), + .to = After(switch_case), + .owner = switch_case, + .liveness_info = + std::tuple{ + switch_case->getParentOfType() + .getDiscriminant(), + GetExprRegionEndValuesFromRegion(switch_case.getTest())[0], + LivenessKind::kLiveIfNotEqualOrUnknown}, + }); + } + + // If this is not the last case, we need fall-through to the next case. + if (auto *next_node = switch_case->getNextNode(); next_node != nullptr) { + if (auto successor_case = llvm::dyn_cast(next_node); + successor_case != nullptr) { + MaybeEmplaceCfgEdges({ + .from = After(switch_case.getConsequent()), + .to = Before(successor_case.getConsequent()), + .owner = switch_case, + }); + } + } else { + MaybeEmplaceCfgEdges({ + .from = After(switch_case.getConsequent()), + .to = After(switch_case), + .owner = switch_case, + }); } } // Get optional jump targets and label to be used during recursive // initialization. These variables use RAII. - auto with_jump_targets = WithJumpTargets(op); + auto with_jump_targets = WithJumpTargets(maybe_jump_targets); auto with_label = WithLabel(op); // Recursively initialize. @@ -1441,8 +1507,8 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::initialize( return mlir::success(); } -template -void JsirDenseDataFlowAnalysis::InitializeBlock( +template +void JsirDataFlowAnalysis::InitializeBlock( mlir::Block *block) { // Initialize all inner ops. 
for (mlir::Operation &op : *block) { @@ -1482,9 +1548,9 @@ void JsirDenseDataFlowAnalysis::InitializeBlock( } } -template -void JsirDenseDataFlowAnalysis::InitializeBlockDependencies( - mlir::Block *block) { +template +void JsirDataFlowAnalysis:: + InitializeBlockDependencies(mlir::Block *block) { if constexpr (direction == DataflowDirection::kForward) { // For each block, we should update its successor blocks when the state // at the end of the block updates. Thus, we enumerate each predecessor's @@ -1504,8 +1570,8 @@ void JsirDenseDataFlowAnalysis::InitializeBlockDependencies( } } -template -mlir::LogicalResult JsirDenseDataFlowAnalysis::visit( +template +mlir::LogicalResult JsirDataFlowAnalysis::visit( mlir::ProgramPoint *point) { if (!point->isBlockStart()) { VisitOp(point->getPrevOp()); @@ -1515,8 +1581,8 @@ mlir::LogicalResult JsirDenseDataFlowAnalysis::visit( return mlir::success(); } -template -void JsirDenseDataFlowAnalysis::VisitOp( +template +void JsirDataFlowAnalysis::VisitOp( mlir::Operation *op) { if constexpr (direction == DataflowDirection::kForward) { JsirStateRef before_state_ref = GetStateBefore(op); @@ -1535,33 +1601,16 @@ void JsirDenseDataFlowAnalysis::VisitOp( } } -template -void JsirDenseDataFlowAnalysis::VisitBlock( +template +void JsirDataFlowAnalysis::VisitBlock( mlir::Block *block) { for (auto *edge : block_to_cfg_edges_[block]) { VisitCfgEdge(edge); } } -template -void JsirDenseDataFlowAnalysis::VisitCfgEdge( - JsirGeneralCfgEdge *edge) { - JsirStateRef pred_state_ref = GetStateImpl(edge->getPred()); - JsirStateRef succ_state_ref = GetStateImpl(edge->getSucc()); - - if constexpr (direction == DataflowDirection::kForward) { - // Merge the predecessor into the successor. - pred_state_ref.AddDependent(edge->getSucc()); - succ_state_ref.Join(pred_state_ref.value()); - } else if constexpr (direction == DataflowDirection::kBackward) { - // Merge the successor into the predecessor. 
- succ_state_ref.AddDependent(edge->getPred()); - pred_state_ref.Join(succ_state_ref.value()); - } -} - -template -void JsirDenseDataFlowAnalysis::PrintAtBlockEntry( +template +void JsirDataFlowAnalysis::PrintAtBlockEntry( mlir::Block &block, size_t num_indents, llvm::raw_ostream &os) { os.indent(num_indents + 2); os << "// "; @@ -1569,24 +1618,10 @@ void JsirDenseDataFlowAnalysis::PrintAtBlockEntry( os << "\n"; } -template -void JsirDenseDataFlowAnalysis::PrintAfterOp( - mlir::Operation *op, size_t num_indents, mlir::AsmState &asm_state, - llvm::raw_ostream &os) { - os << "\n"; - os.indent(num_indents + 2); - os << "// "; - GetStateAfter(op).value().print(os); -} - -// ============================================================================= -// JsirDataFlowAnalysis -// ============================================================================= - template JsirStateRef JsirDataFlowAnalysis::GetStateAt(mlir::Value value) { - return Base::template GetStateImpl(value); + return GetStateImpl(value); } template @@ -1604,19 +1639,10 @@ void JsirDataFlowAnalysis::PrintAfterOp( result_state_ref.value().print(os); } - Base::PrintAfterOp(op, num_indents, asm_state, os); -} - -template -mlir::LogicalResult JsirDataFlowAnalysis::initialize( - mlir::Operation *op) { - // The op depends on its input operands. - for (mlir::Value operand : op->getOperands()) { - JsirStateRef operand_state_ref = GetStateAt(operand); - operand_state_ref.AddDependent(Base::getProgramPointAfter(op)); - } - - return Base::initialize(op); + os << "\n"; + os.indent(num_indents + 2); + os << "// "; + GetStateAfter(op).value().print(os); } template @@ -1691,7 +1717,18 @@ void JsirDataFlowAnalysis::VisitCfgEdge( } } - Base::VisitCfgEdge(edge); + JsirStateRef pred_state_ref = GetStateImpl(edge->getPred()); + JsirStateRef succ_state_ref = GetStateImpl(edge->getSucc()); + + if constexpr (direction == DataflowDirection::kForward) { + // Merge the predecessor into the successor. 
+ pred_state_ref.AddDependent(edge->getSucc()); + succ_state_ref.Join(pred_state_ref.value()); + } else if constexpr (direction == DataflowDirection::kBackward) { + // Merge the successor into the predecessor. + succ_state_ref.AddDependent(edge->getPred()); + pred_state_ref.Join(succ_state_ref.value()); + } } } // namespace maldoca diff --git a/maldoca/js/ir/jsir_gen.cc b/maldoca/js/ir/jsir_gen.cc index d8a82e0..59884a5 100644 --- a/maldoca/js/ir/jsir_gen.cc +++ b/maldoca/js/ir/jsir_gen.cc @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/maldoca/js/ir/jsir_gen_lib.cc b/maldoca/js/ir/jsir_gen_lib.cc index 720c7f3..78f57eb 100644 --- a/maldoca/js/ir/jsir_gen_lib.cc +++ b/maldoca/js/ir/jsir_gen_lib.cc @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/maldoca/js/ir/jsir_ops.generated.td b/maldoca/js/ir/jsir_ops.generated.td index eed04be..2814e05 100644 --- a/maldoca/js/ir/jsir_ops.generated.td +++ b/maldoca/js/ir/jsir_ops.generated.td @@ -19,12 +19,12 @@ #ifndef MALDOCA_JS_IR_JSIR_OPS_GENERATED_TD_ #define MALDOCA_JS_IR_JSIR_OPS_GENERATED_TD_ -include "mlir/IR/OpBase.td" -include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" include "maldoca/js/ir/interfaces.td" include "maldoca/js/ir/jsir_dialect.td" include "maldoca/js/ir/jsir_types.td" diff --git a/maldoca/js/ir/transforms/dynamic_constant_propagation/tests/variable_inline_second/README.generated.md b/maldoca/js/ir/transforms/dynamic_constant_propagation/tests/variable_inline_second/README.generated.md index e671340..71ee896 100644 --- a/maldoca/js/ir/transforms/dynamic_constant_propagation/tests/variable_inline_second/README.generated.md +++ b/maldoca/js/ir/transforms/dynamic_constant_propagation/tests/variable_inline_second/README.generated.md @@ -2,6 +2,6 @@ To run manually: ```shell bazel run //maldoca/js/ir:jsir_gen -- \ - --input_file $(pwd)/maldoca/js/ir/transforms/dynamic_constant_propagation/tests/variable_inline.second/input.js \ + --input_file $(pwd)/maldoca/js/ir/transforms/dynamic_constant_propagation/tests/variable_inline_second/input.js \ --passes "source2ast,extract_prelude,erase_comments,ast2hir,hir2lir,dynconstprop,lir2hir,hir2ast,ast2source" ```