From 30a8dbfe5881d4739d962837c89bcc45b7280296 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 06:56:39 -0400 Subject: [PATCH 01/11] ARROW-15635: [C++] Support nested extension-id-registry --- .../arrow/engine/substrait/extension_set.cc | 184 +++++++++++++----- .../arrow/engine/substrait/extension_set.h | 22 ++- .../arrow/engine/substrait/plan_internal.cc | 2 +- .../arrow/engine/substrait/plan_internal.h | 2 +- cpp/src/arrow/engine/substrait/util.cc | 4 + cpp/src/arrow/engine/substrait/util.h | 3 + 6 files changed, 166 insertions(+), 51 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 80cdf59f496..2b005b60d11 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -107,14 +107,14 @@ struct ExtensionSet::Impl { std::unordered_map types_, functions_; }; -ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) +ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry) : registry_(registry), impl_(new Impl(), [](Impl* impl) { delete impl; }) {} Result ExtensionSet::Make(std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { ExtensionSet set; set.registry_ = registry; @@ -210,49 +210,8 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } -ExtensionIdRegistry* default_extension_id_registry() { - static struct Impl : ExtensionIdRegistry { - Impl() { - struct TypeName { - std::shared_ptr type; - util::string_view name; - }; - - // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; - // see ARROW-15535. - for (TypeName e : { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/true)); - } - - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/false)); - } - - // TODO: this is just a placeholder right now. We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (util::string_view name : { - "add", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); - } - } - +namespace { + struct ExtensionIdRegistryImpl : ExtensionIdRegistry { std::vector Uris() const override { return {uris_.begin(), uris_.end()}; } @@ -272,6 +231,18 @@ ExtensionIdRegistry* default_extension_id_registry() { return {}; } + virtual Status CanRegisterType(Id id, std::shared_ptr type, + bool is_variation) const { + auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; + if (id_to_index.find(id) != id_to_index.end()) { + return Status::Invalid("Type id was already registered"); + } + if (type_to_index_.find(&*type) != type_to_index_.end()) { + return Status::Invalid("Type was already registered"); + } + return Status::OK(); + } + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { DCHECK_EQ(type_ids_.size(), types_.size()); @@ -315,6 +286,17 @@ ExtensionIdRegistry* default_extension_id_registry() { return {}; } + virtual Status CanRegisterFunction(Id id, std::string arrow_function_name) const { + if (function_id_to_index_.find(id) == function_id_to_index_.end()) { + return Status::Invalid("Function id was already registered"); + } + if (function_name_to_index_.find(arrow_function_name) == + function_name_to_index_.end()) { + return Status::Invalid("Function name was already registered"); + } + return Status::OK(); + } + Status RegisterFunction(Id id, std::string arrow_function_name) override { DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); @@ -358,10 +340,120 @@ ExtensionIdRegistry* default_extension_id_registry() { std::unordered_map function_id_to_index_; std::unordered_map function_name_to_index_; - } impl_; + }; + + struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { + NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) + : parent_(parent) {} + std::vector Uris() const override { + std::vector uris = parent_->Uris(); + std::unordered_set uri_set; + uri_set.insert(uris.begin(), uris.end()); + uri_set.insert(uris_.begin(), uris_.end()); + return std::vector(uris); + } + + util::optional GetType(const DataType& type) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(type); + if (type_opt) { + return type_opt; + } + return parent_->GetType(type); + } + + util::optional GetType(Id id, bool is_variation) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(id, is_variation); + if (type_opt) { + return type_opt; + } + return parent_->GetType(id, is_variation); + } + + Status RegisterType(Id id, std::shared_ptr type, + bool is_variation) override { + return parent_->CanRegisterType(id, type, is_variation) & + ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); + } + + util::optional GetFunction( + util::string_view arrow_function_name) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(arrow_function_name); + } + + util::optional GetFunction(Id id) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(id); + } + + Status RegisterFunction(Id id, std::string arrow_function_name) override { + return parent_->CanRegisterFunction(id, arrow_function_name) & + ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); + } + + const ExtensionIdRegistry* parent_; + }; + + struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { + DefaultExtensionIdRegistry() { + struct TypeName { + std::shared_ptr type; + util::string_view name; + }; + + // The type (variation) mappings listed below need to be kept in sync + // with the YAML at substrait/format/extension_types.yaml manually; + // see ARROW-15535. + for (TypeName e : { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/true)); + } + + for (TypeName e : { + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/false)); + } + + // TODO: this is just a placeholder right now. We'll need a YAML file for + // all functions (and prototypes) that Arrow provides that are relevant + // for Substrait, and include mappings for all of them here. See + // ARROW-15535. + for (util::string_view name : { + "add", + }) { + DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + } + } + }; +} // namespace + +ExtensionIdRegistry* default_extension_id_registry() { + static DefaultExtensionIdRegistry impl_; return &impl_; } +std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry *parent) { + return std::make_shared(parent); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 951f7ffa3a1..d94a4adb578 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -63,6 +63,8 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { }; virtual util::optional GetType(const DataType&) const = 0; virtual util::optional GetType(Id, bool is_variation) const = 0; + virtual Status CanRegisterType(Id, std::shared_ptr type, + bool is_variation) const = 0; virtual Status RegisterType(Id, std::shared_ptr, bool is_variation) = 0; /// \brief A mapping between a Substrait ID and an Arrow function @@ -84,6 +86,7 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { virtual util::optional GetFunction(Id) const = 0; virtual util::optional GetFunction( util::string_view arrow_function_name) const = 0; + virtual Status CanRegisterFunction(Id, std::string arrow_function_name) const = 0; virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; }; @@ -96,6 +99,19 @@ constexpr util::string_view kArrowExtTypesUri = /// Note: Function support is currently very minimal, see ARROW-15538 ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry(); +/// \brief Makes a nested registry with a given parent. +/// +/// A nested registry supports registering types and functions other and on top of those +/// already registered in its parent registry. No conflicts in IDs and names used for +/// lookup are allowed. Normally, the given parent is the default registry. +/// +/// One use case for a nested registry is for dynamic registration of functions defined +/// within a Substrait plan while keeping these registrations specific to the plan. When +/// the Substrait plan is disposed of, normally after its execution, the nested registry +/// can be disposed of as well. +ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry *parent); + /// \brief A set of extensions used within a plan /// /// Each time an extension is used within a Substrait plan the extension @@ -140,7 +156,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { }; /// Construct an empty ExtensionSet to be populated during serialization. - explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry()); + explicit ExtensionSet(const ExtensionIdRegistry* = default_extension_id_registry()); ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet); /// Construct an ExtensionSet with explicit extension ids for efficient referencing @@ -160,7 +176,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { static Result Make( std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* = default_extension_id_registry()); + const ExtensionIdRegistry* = default_extension_id_registry()); // index in these vectors == value of _anchor/_reference fields /// TODO(ARROW-15583) this assumes that _anchor/_references won't be huge, which is not @@ -224,7 +240,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { std::size_t num_functions() const { return functions_.size(); } private: - ExtensionIdRegistry* registry_; + const ExtensionIdRegistry* registry_; /// The subset of extension registry URIs referenced by this extension set std::vector uris_; std::vector types_; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 5813dcde24c..8b3d53ae794 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -111,7 +111,7 @@ void SetElement(size_t i, const Element& element, std::vector* vector) { } // namespace Result GetExtensionSetFromPlan(const substrait::Plan& plan, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { std::vector uris; for (const auto& uri : plan.extension_uris()) { SetElement(uri.extension_uri_anchor(), uri.uri(), &uris); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 281cab0c0f3..dce23cdceba 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -49,7 +49,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( const substrait::Plan& plan, - ExtensionIdRegistry* registry = default_extension_id_registry()); + const ExtensionIdRegistry* registry = default_extension_id_registry()); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index bc2aa36856e..2ae3771f3fb 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -123,6 +123,10 @@ Result> SerializeJsonPlan(const std::string& substrait_j return engine::internal::SubstraitFromJSON("Plan", substrait_json); } +std::shared_ptr MakeExtensionIdRegistry() { + return nested_extension_id_registry(default_extension_id_registry()); +} + } // namespace substrait } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 860a459da2f..a096b98ef88 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -37,6 +37,9 @@ ARROW_ENGINE_EXPORT Result> ExecuteSerialized ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); +ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry() ; + + } // namespace substrait } // namespace engine From a8f1104b1fd7bf30aef6682707cf07a5cbb4e4b9 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 07:17:07 -0400 Subject: [PATCH 02/11] lint --- .../arrow/engine/substrait/extension_set.cc | 384 +++++++++--------- .../arrow/engine/substrait/extension_set.h | 2 +- cpp/src/arrow/engine/substrait/util.h | 2 +- 3 files changed, 195 insertions(+), 193 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 2b005b60d11..f8009fb6676 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -211,238 +211,240 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { } namespace { - struct ExtensionIdRegistryImpl : ExtensionIdRegistry { - std::vector Uris() const override { - return {uris_.begin(), uris_.end()}; - } - util::optional GetType(const DataType& type) const override { - if (auto index = GetIndex(type_to_index_, &type)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; - } +struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + std::vector Uris() const override { + return {uris_.begin(), uris_.end()}; + } - util::optional GetType(Id id, bool is_variation) const override { - if (auto index = - GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; + util::optional GetType(const DataType& type) const override { + if (auto index = GetIndex(type_to_index_, &type)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - virtual Status CanRegisterType(Id id, std::shared_ptr type, - bool is_variation) const { - auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; - if (id_to_index.find(id) != id_to_index.end()) { - return Status::Invalid("Type id was already registered"); - } - if (type_to_index_.find(&*type) != type_to_index_.end()) { - return Status::Invalid("Type was already registered"); - } - return Status::OK(); + util::optional GetType(Id id, bool is_variation) const override { + if (auto index = + GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - Status RegisterType(Id id, std::shared_ptr type, - bool is_variation) override { - DCHECK_EQ(type_ids_.size(), types_.size()); - DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); - - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + virtual Status CanRegisterType(Id id, std::shared_ptr type, + bool is_variation) const { + auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; + if (id_to_index.find(id) != id_to_index.end()) { + return Status::Invalid("Type id was already registered"); + } + if (type_to_index_.find(&*type) != type_to_index_.end()) { + return Status::Invalid("Type was already registered"); + } + return Status::OK(); + } - auto index = static_cast(type_ids_.size()); + Status RegisterType(Id id, std::shared_ptr type, + bool is_variation) override { + DCHECK_EQ(type_ids_.size(), types_.size()); + DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); - auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; - auto it_success = id_to_index->emplace(copied_id, index); + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; - if (!it_success.second) { - return Status::Invalid("Type id was already registered"); - } + auto index = static_cast(type_ids_.size()); - if (!type_to_index_.emplace(type.get(), index).second) { - id_to_index->erase(it_success.first); - return Status::Invalid("Type was already registered"); - } + auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; + auto it_success = id_to_index->emplace(copied_id, index); - type_ids_.push_back(copied_id); - types_.push_back(std::move(type)); - type_is_variation_.push_back(is_variation); - return Status::OK(); + if (!it_success.second) { + return Status::Invalid("Type id was already registered"); } - util::optional GetFunction( - util::string_view arrow_function_name) const override { - if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + if (!type_to_index_.emplace(type.get(), index).second) { + id_to_index->erase(it_success.first); + return Status::Invalid("Type was already registered"); } - util::optional GetFunction(Id id) const override { - if (auto index = GetIndex(function_id_to_index_, id)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; - } + type_ids_.push_back(copied_id); + types_.push_back(std::move(type)); + type_is_variation_.push_back(is_variation); + return Status::OK(); + } - virtual Status CanRegisterFunction(Id id, std::string arrow_function_name) const { - if (function_id_to_index_.find(id) == function_id_to_index_.end()) { - return Status::Invalid("Function id was already registered"); - } - if (function_name_to_index_.find(arrow_function_name) == - function_name_to_index_.end()) { - return Status::Invalid("Function name was already registered"); - } - return Status::OK(); + util::optional GetFunction( + util::string_view arrow_function_name) const override { + if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; } + return {}; + } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + util::optional GetFunction(Id id) const override { + if (auto index = GetIndex(function_id_to_index_, id)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + } + return {}; + } - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + virtual Status CanRegisterFunction(Id id, std::string arrow_function_name) const { + if (function_id_to_index_.find(id) == function_id_to_index_.end()) { + return Status::Invalid("Function id was already registered"); + } + if (function_name_to_index_.find(arrow_function_name) == + function_name_to_index_.end()) { + return Status::Invalid("Function name was already registered"); + } + return Status::OK(); + } - const std::string& copied_function_name{ - *function_names_.emplace(std::move(arrow_function_name)).first}; + Status RegisterFunction(Id id, std::string arrow_function_name) override { + DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); - auto index = static_cast(function_ids_.size()); + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; - auto it_success = function_id_to_index_.emplace(copied_id, index); + const std::string& copied_function_name{ + *function_names_.emplace(std::move(arrow_function_name)).first}; - if (!it_success.second) { - return Status::Invalid("Function id was already registered"); - } + auto index = static_cast(function_ids_.size()); - if (!function_name_to_index_.emplace(copied_function_name, index).second) { - function_id_to_index_.erase(it_success.first); - return Status::Invalid("Function name was already registered"); - } + auto it_success = function_id_to_index_.emplace(copied_id, index); - function_name_ptrs_.push_back(&copied_function_name); - function_ids_.push_back(copied_id); - return Status::OK(); + if (!it_success.second) { + return Status::Invalid("Function id was already registered"); } - // owning storage of uris, names, (arrow::)function_names, types - // note that storing strings like this is safe since references into an - // unordered_set are not invalidated on insertion - std::unordered_set uris_, names_, function_names_; - DataTypeVector types_; - std::vector type_is_variation_; - - // non-owning lookup helpers - std::vector type_ids_, function_ids_; - std::unordered_map id_to_index_, variation_id_to_index_; - std::unordered_map type_to_index_; - - std::vector function_name_ptrs_; - std::unordered_map function_id_to_index_; - std::unordered_map - function_name_to_index_; - }; - - struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { - NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) - : parent_(parent) {} - - std::vector Uris() const override { - std::vector uris = parent_->Uris(); - std::unordered_set uri_set; - uri_set.insert(uris.begin(), uris.end()); - uri_set.insert(uris_.begin(), uris_.end()); - return std::vector(uris); + if (!function_name_to_index_.emplace(copied_function_name, index).second) { + function_id_to_index_.erase(it_success.first); + return Status::Invalid("Function name was already registered"); } - util::optional GetType(const DataType& type) const override { - auto type_opt = ExtensionIdRegistryImpl::GetType(type); - if (type_opt) { - return type_opt; - } - return parent_->GetType(type); + function_name_ptrs_.push_back(&copied_function_name); + function_ids_.push_back(copied_id); + return Status::OK(); + } + + // owning storage of uris, names, (arrow::)function_names, types + // note that storing strings like this is safe since references into an + // unordered_set are not invalidated on insertion + std::unordered_set uris_, names_, function_names_; + DataTypeVector types_; + std::vector type_is_variation_; + + // non-owning lookup helpers + std::vector type_ids_, function_ids_; + std::unordered_map id_to_index_, variation_id_to_index_; + std::unordered_map type_to_index_; + + std::vector function_name_ptrs_; + std::unordered_map function_id_to_index_; + std::unordered_map + function_name_to_index_; +}; + +struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { + NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) + : parent_(parent) {} + + std::vector Uris() const override { + std::vector uris = parent_->Uris(); + std::unordered_set uri_set; + uri_set.insert(uris.begin(), uris.end()); + uri_set.insert(uris_.begin(), uris_.end()); + return std::vector(uris); + } + + util::optional GetType(const DataType& type) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(type); + if (type_opt) { + return type_opt; } + return parent_->GetType(type); + } - util::optional GetType(Id id, bool is_variation) const override { - auto type_opt = ExtensionIdRegistryImpl::GetType(id, is_variation); - if (type_opt) { - return type_opt; - } - return parent_->GetType(id, is_variation); + util::optional GetType(Id id, bool is_variation) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(id, is_variation); + if (type_opt) { + return type_opt; } + return parent_->GetType(id, is_variation); + } + + Status RegisterType(Id id, std::shared_ptr type, + bool is_variation) override { + return parent_->CanRegisterType(id, type, is_variation) & + ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); + } - Status RegisterType(Id id, std::shared_ptr type, - bool is_variation) override { - return parent_->CanRegisterType(id, type, is_variation) & - ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); + util::optional GetFunction( + util::string_view arrow_function_name) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); + if (func_opt) { + return func_opt; } + return parent_->GetFunction(arrow_function_name); + } - util::optional GetFunction( - util::string_view arrow_function_name) const override { - auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); - if (func_opt) { - return func_opt; - } - return parent_->GetFunction(arrow_function_name); + util::optional GetFunction(Id id) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); + if (func_opt) { + return func_opt; } + return parent_->GetFunction(id); + } - util::optional GetFunction(Id id) const override { - auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); - if (func_opt) { - return func_opt; - } - return parent_->GetFunction(id); + Status RegisterFunction(Id id, std::string arrow_function_name) override { + return parent_->CanRegisterFunction(id, arrow_function_name) & + ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); + } + + const ExtensionIdRegistry* parent_; +}; + +struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { + DefaultExtensionIdRegistry() { + struct TypeName { + std::shared_ptr type; + util::string_view name; + }; + + // The type (variation) mappings listed below need to be kept in sync + // with the YAML at substrait/format/extension_types.yaml manually; + // see ARROW-15535. + for (TypeName e : { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/true)); } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - return parent_->CanRegisterFunction(id, arrow_function_name) & - ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); + for (TypeName e : { + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/false)); } - const ExtensionIdRegistry* parent_; - }; - - struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { - DefaultExtensionIdRegistry() { - struct TypeName { - std::shared_ptr type; - util::string_view name; - }; - - // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; - // see ARROW-15535. - for (TypeName e : { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/true)); - } - - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/false)); - } - - // TODO: this is just a placeholder right now. We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (util::string_view name : { - "add", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); - } + // TODO: this is just a placeholder right now. We'll need a YAML file for + // all functions (and prototypes) that Arrow provides that are relevant + // for Substrait, and include mappings for all of them here. See + // ARROW-15535. + for (util::string_view name : { + "add", + }) { + DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); } - }; + } +}; + } // namespace ExtensionIdRegistry* default_extension_id_registry() { @@ -451,7 +453,7 @@ ExtensionIdRegistry* default_extension_id_registry() { } std::shared_ptr nested_extension_id_registry( - const ExtensionIdRegistry *parent) { + const ExtensionIdRegistry* parent) { return std::make_shared(parent); } diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index d94a4adb578..a6019333fc0 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -110,7 +110,7 @@ ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry(); /// the Substrait plan is disposed of, normally after its execution, the nested registry /// can be disposed of as well. ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_registry( - const ExtensionIdRegistry *parent); + const ExtensionIdRegistry* parent); /// \brief A set of extensions used within a plan /// diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index a096b98ef88..1775d9dc4ae 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -37,7 +37,7 @@ ARROW_ENGINE_EXPORT Result> ExecuteSerialized ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); -ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry() ; +ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); } // namespace substrait From e8406936fb230aa6def1438ce3817003994126b7 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 07:45:21 -0400 Subject: [PATCH 03/11] lint --- cpp/src/arrow/engine/substrait/extension_set.cc | 15 ++++++--------- cpp/src/arrow/engine/substrait/util.h | 1 - 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index f8009fb6676..df1545b79e2 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -225,15 +225,14 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { } util::optional GetType(Id id, bool is_variation) const override { - if (auto index = - GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { + if (auto index = GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } return {}; } virtual Status CanRegisterType(Id id, std::shared_ptr type, - bool is_variation) const { + bool is_variation) const { auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; if (id_to_index.find(id) != id_to_index.end()) { return Status::Invalid("Type id was already registered"); @@ -244,8 +243,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return Status::OK(); } - Status RegisterType(Id id, std::shared_ptr type, - bool is_variation) override { + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { DCHECK_EQ(type_ids_.size(), types_.size()); DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); @@ -344,8 +342,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { }; struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { - NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) - : parent_(parent) {} + NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} std::vector Uris() const override { std::vector uris = parent_->Uris(); @@ -374,7 +371,7 @@ struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { return parent_->CanRegisterType(id, type, is_variation) & - ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); + ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); } util::optional GetFunction( @@ -396,7 +393,7 @@ struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { Status RegisterFunction(Id id, std::string arrow_function_name) override { return parent_->CanRegisterFunction(id, arrow_function_name) & - ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); + ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); } const ExtensionIdRegistry* parent_; diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 1775d9dc4ae..3b17bd880f2 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -39,7 +39,6 @@ ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); - } // namespace substrait } // namespace engine From cc38b93650ed2d061fb2621c3977d6bbc7565f06 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 08:13:28 -0400 Subject: [PATCH 04/11] lint --- cpp/src/arrow/engine/substrait/extension_set.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index df1545b79e2..989ea4267ce 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -368,8 +368,7 @@ struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { return parent_->GetType(id, is_variation); } - Status RegisterType(Id id, std::shared_ptr type, - bool is_variation) override { + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { return parent_->CanRegisterType(id, type, is_variation) & ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); } From 50bc1aa92a4268bc861f5c79cae976250f9e4357 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 08:32:28 -0400 Subject: [PATCH 05/11] lint --- cpp/src/arrow/engine/substrait/extension_set.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 989ea4267ce..2c2cbc65913 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -342,7 +342,8 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { }; struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { - NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} + explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) + : parent_(parent) {} std::vector Uris() const override { std::vector uris = parent_->Uris(); From 2e05a98cd11812d5330f04c08d0c8b185b57a116 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 08:59:48 -0400 Subject: [PATCH 06/11] Add virtual dtors to extension-id-registry impls --- cpp/src/arrow/engine/substrait/extension_set.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 2c2cbc65913..075735f1908 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -213,6 +213,8 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { namespace { struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + virtual ~ExtensionIdRegistryImpl() {} + std::vector Uris() const override { return {uris_.begin(), uris_.end()}; } @@ -345,6 +347,8 @@ struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} + virtual ~NestedExtensionIdRegistryImpl() {} + std::vector Uris() const override { std::vector uris = parent_->Uris(); std::unordered_set uri_set; From fa9e6bdfa3417623e58982b9fc10b08baf0e1c1e Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 11:31:59 -0400 Subject: [PATCH 07/11] Add override modifiers --- cpp/src/arrow/engine/substrait/extension_set.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 075735f1908..5d6b6d1f6da 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -234,7 +234,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { } virtual Status CanRegisterType(Id id, std::shared_ptr type, - bool is_variation) const { + bool is_variation) const override { auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; if (id_to_index.find(id) != id_to_index.end()) { return Status::Invalid("Type id was already registered"); @@ -287,7 +287,8 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - virtual Status CanRegisterFunction(Id id, std::string arrow_function_name) const { + virtual Status CanRegisterFunction( + Id id, std::string arrow_function_name) const override { if (function_id_to_index_.find(id) == function_id_to_index_.end()) { return Status::Invalid("Function id was already registered"); } From 5eb852f232ea66772d431431792d34aafe4d99b7 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 14:25:20 -0400 Subject: [PATCH 08/11] lint --- cpp/src/arrow/engine/substrait/extension_set.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 5d6b6d1f6da..63a65a40d78 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -287,8 +287,8 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - virtual Status CanRegisterFunction( - Id id, std::string arrow_function_name) const override { + virtual Status CanRegisterFunction(Id id, + std::string arrow_function_name) const override { if (function_id_to_index_.find(id) == function_id_to_index_.end()) { return Status::Invalid("Function id was already registered"); } From ad42d45e15ae0145df276e90a629349445aa59a8 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 14:47:57 -0400 Subject: [PATCH 09/11] lint --- cpp/src/arrow/engine/substrait/extension_set.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 63a65a40d78..a8fa7acbba1 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -233,8 +233,8 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - virtual Status CanRegisterType(Id id, std::shared_ptr type, - bool is_variation) const override { + Status CanRegisterType(Id id, std::shared_ptr type, + bool is_variation) const override { auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; if (id_to_index.find(id) != id_to_index.end()) { return Status::Invalid("Type id was already registered"); @@ -287,8 +287,8 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - virtual Status CanRegisterFunction(Id id, - std::string arrow_function_name) const override { + Status CanRegisterFunction(Id id, + std::string arrow_function_name) const override { if (function_id_to_index_.find(id) == function_id_to_index_.end()) { return Status::Invalid("Function id was already registered"); } From f3dde55dffb4e422420f647b1d1cc8213a3ca3f7 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 22 May 2022 15:15:40 -0400 Subject: [PATCH 10/11] lint --- cpp/src/arrow/engine/substrait/extension_set.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index a8fa7acbba1..cdb34db10d1 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -287,8 +287,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - Status CanRegisterFunction(Id id, - std::string arrow_function_name) const override { + Status CanRegisterFunction(Id id, std::string arrow_function_name) const override { if (function_id_to_index_.find(id) == function_id_to_index_.end()) { return Status::Invalid("Function id was already registered"); } From d09180743e8ade4be8e21d8b5505b77a34372cd3 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 24 May 2022 05:56:29 -0400 Subject: [PATCH 11/11] bug fix --- cpp/src/arrow/engine/substrait/extension_set.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index cdb34db10d1..393d71f5f3f 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -288,10 +288,10 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { } Status CanRegisterFunction(Id id, std::string arrow_function_name) const override { - if (function_id_to_index_.find(id) == function_id_to_index_.end()) { + if (function_id_to_index_.find(id) != function_id_to_index_.end()) { return Status::Invalid("Function id was already registered"); } - if (function_name_to_index_.find(arrow_function_name) == + if (function_name_to_index_.find(arrow_function_name) != function_name_to_index_.end()) { return Status::Invalid("Function name was already registered"); }