diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 80cdf59f496..393d71f5f3f 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -107,14 +107,14 @@ struct ExtensionSet::Impl { std::unordered_map types_, functions_; }; -ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) +ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry) : registry_(registry), impl_(new Impl(), [](Impl* impl) { delete impl; }) {} Result ExtensionSet::Make(std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { ExtensionSet set; set.registry_ = registry; @@ -210,158 +210,253 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } -ExtensionIdRegistry* default_extension_id_registry() { - static struct Impl : ExtensionIdRegistry { - Impl() { - struct TypeName { - std::shared_ptr type; - util::string_view name; - }; - - // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; - // see ARROW-15535. - for (TypeName e : { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/true)); - } - - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/false)); - } - - // TODO: this is just a placeholder right now. We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (util::string_view name : { - "add", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); - } - } +namespace { + +struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + virtual ~ExtensionIdRegistryImpl() {} - std::vector Uris() const override { - return {uris_.begin(), uris_.end()}; + std::vector Uris() const override { + return {uris_.begin(), uris_.end()}; + } + + util::optional GetType(const DataType& type) const override { + if (auto index = GetIndex(type_to_index_, &type)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - util::optional GetType(const DataType& type) const override { - if (auto index = GetIndex(type_to_index_, &type)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; + util::optional GetType(Id id, bool is_variation) const override { + if (auto index = GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - util::optional GetType(Id id, bool is_variation) const override { - if (auto index = - GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; + Status CanRegisterType(Id id, std::shared_ptr type, + bool is_variation) const override { + auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; + if (id_to_index.find(id) != id_to_index.end()) { + return Status::Invalid("Type id was already registered"); + } + if (type_to_index_.find(&*type) != type_to_index_.end()) { + return Status::Invalid("Type was already registered"); } + return Status::OK(); + } - Status RegisterType(Id id, std::shared_ptr type, - bool is_variation) override { - DCHECK_EQ(type_ids_.size(), types_.size()); - DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { + DCHECK_EQ(type_ids_.size(), types_.size()); + DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; - auto index = static_cast(type_ids_.size()); + auto index = static_cast(type_ids_.size()); - auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; - auto it_success = id_to_index->emplace(copied_id, index); + auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; + auto it_success = id_to_index->emplace(copied_id, index); - if (!it_success.second) { - return Status::Invalid("Type id was already registered"); - } + if (!it_success.second) { + return Status::Invalid("Type id was already registered"); + } - if (!type_to_index_.emplace(type.get(), index).second) { - id_to_index->erase(it_success.first); - return Status::Invalid("Type was already registered"); - } + if (!type_to_index_.emplace(type.get(), index).second) { + id_to_index->erase(it_success.first); + return Status::Invalid("Type was already registered"); + } + + type_ids_.push_back(copied_id); + types_.push_back(std::move(type)); + type_is_variation_.push_back(is_variation); + return Status::OK(); + } - type_ids_.push_back(copied_id); - types_.push_back(std::move(type)); - type_is_variation_.push_back(is_variation); - return Status::OK(); + util::optional GetFunction( + util::string_view arrow_function_name) const override { + if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; } + return {}; + } + + util::optional GetFunction(Id id) const override { + if (auto index = GetIndex(function_id_to_index_, id)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + } + return {}; + } + + Status CanRegisterFunction(Id id, std::string arrow_function_name) const override { + if (function_id_to_index_.find(id) != function_id_to_index_.end()) { + return Status::Invalid("Function id was already registered"); + } + if (function_name_to_index_.find(arrow_function_name) != + function_name_to_index_.end()) { + return Status::Invalid("Function name was already registered"); + } + return Status::OK(); + } + + Status RegisterFunction(Id id, std::string arrow_function_name) override { + DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; + + const std::string& copied_function_name{ + *function_names_.emplace(std::move(arrow_function_name)).first}; + + auto index = static_cast(function_ids_.size()); + + auto it_success = function_id_to_index_.emplace(copied_id, index); - util::optional GetFunction( - util::string_view arrow_function_name) const override { - if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + if (!it_success.second) { + return Status::Invalid("Function id was already registered"); } - util::optional GetFunction(Id id) const override { - if (auto index = GetIndex(function_id_to_index_, id)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + if (!function_name_to_index_.emplace(copied_function_name, index).second) { + function_id_to_index_.erase(it_success.first); + return Status::Invalid("Function name was already registered"); } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + function_name_ptrs_.push_back(&copied_function_name); + function_ids_.push_back(copied_id); + return Status::OK(); + } - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + // owning storage of uris, names, (arrow::)function_names, types + // note that storing strings like this is safe since references into an + // unordered_set are not invalidated on insertion + std::unordered_set uris_, names_, function_names_; + DataTypeVector types_; + std::vector type_is_variation_; + + // non-owning lookup helpers + std::vector type_ids_, function_ids_; + std::unordered_map id_to_index_, variation_id_to_index_; + std::unordered_map type_to_index_; + + std::vector function_name_ptrs_; + std::unordered_map function_id_to_index_; + std::unordered_map + function_name_to_index_; +}; - const std::string& copied_function_name{ - *function_names_.emplace(std::move(arrow_function_name)).first}; +struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { + explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) + : parent_(parent) {} - auto index = static_cast(function_ids_.size()); + virtual ~NestedExtensionIdRegistryImpl() {} - auto it_success = function_id_to_index_.emplace(copied_id, index); + std::vector Uris() const override { + std::vector uris = parent_->Uris(); + std::unordered_set uri_set; + uri_set.insert(uris.begin(), uris.end()); + uri_set.insert(uris_.begin(), uris_.end()); + return std::vector(uris); + } - if (!it_success.second) { - return Status::Invalid("Function id was already registered"); - } + util::optional GetType(const DataType& type) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(type); + if (type_opt) { + return type_opt; + } + return parent_->GetType(type); + } + + util::optional GetType(Id id, bool is_variation) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(id, is_variation); + if (type_opt) { + return type_opt; + } + return parent_->GetType(id, is_variation); + } - if (!function_name_to_index_.emplace(copied_function_name, index).second) { - function_id_to_index_.erase(it_success.first); - return Status::Invalid("Function name was already registered"); - } + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { + return parent_->CanRegisterType(id, type, is_variation) & + ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); + } - function_name_ptrs_.push_back(&copied_function_name); - function_ids_.push_back(copied_id); - return Status::OK(); + util::optional GetFunction( + util::string_view arrow_function_name) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); + if (func_opt) { + return func_opt; } + return parent_->GetFunction(arrow_function_name); + } - // owning storage of uris, names, (arrow::)function_names, types - // note that storing strings like this is safe since references into an - // unordered_set are not invalidated on insertion - std::unordered_set uris_, names_, function_names_; - DataTypeVector types_; - std::vector type_is_variation_; + util::optional GetFunction(Id id) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(id); + } - // non-owning lookup helpers - std::vector type_ids_, function_ids_; - std::unordered_map id_to_index_, variation_id_to_index_; - std::unordered_map type_to_index_; + Status RegisterFunction(Id id, std::string arrow_function_name) override { + return parent_->CanRegisterFunction(id, arrow_function_name) & + ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); + } - std::vector function_name_ptrs_; - std::unordered_map function_id_to_index_; - std::unordered_map - function_name_to_index_; - } impl_; + const ExtensionIdRegistry* parent_; +}; +struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { + DefaultExtensionIdRegistry() { + struct TypeName { + std::shared_ptr type; + util::string_view name; + }; + + // The type (variation) mappings listed below need to be kept in sync + // with the YAML at substrait/format/extension_types.yaml manually; + // see ARROW-15535. + for (TypeName e : { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/true)); + } + + for (TypeName e : { + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/false)); + } + + // TODO: this is just a placeholder right now. We'll need a YAML file for + // all functions (and prototypes) that Arrow provides that are relevant + // for Substrait, and include mappings for all of them here. See + // ARROW-15535. + for (util::string_view name : { + "add", + }) { + DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + } + } +}; + +} // namespace + +ExtensionIdRegistry* default_extension_id_registry() { + static DefaultExtensionIdRegistry impl_; return &impl_; } +std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent) { + return std::make_shared(parent); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 951f7ffa3a1..a6019333fc0 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -63,6 +63,8 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { }; virtual util::optional GetType(const DataType&) const = 0; virtual util::optional GetType(Id, bool is_variation) const = 0; + virtual Status CanRegisterType(Id, std::shared_ptr type, + bool is_variation) const = 0; virtual Status RegisterType(Id, std::shared_ptr, bool is_variation) = 0; /// \brief A mapping between a Substrait ID and an Arrow function @@ -84,6 +86,7 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { virtual util::optional GetFunction(Id) const = 0; virtual util::optional GetFunction( util::string_view arrow_function_name) const = 0; + virtual Status CanRegisterFunction(Id, std::string arrow_function_name) const = 0; virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; }; @@ -96,6 +99,19 @@ constexpr util::string_view kArrowExtTypesUri = /// Note: Function support is currently very minimal, see ARROW-15538 ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry(); +/// \brief Makes a nested registry with a given parent. +/// +/// A nested registry supports registering types and functions other and on top of those +/// already registered in its parent registry. No conflicts in IDs and names used for +/// lookup are allowed. Normally, the given parent is the default registry. +/// +/// One use case for a nested registry is for dynamic registration of functions defined +/// within a Substrait plan while keeping these registrations specific to the plan. When +/// the Substrait plan is disposed of, normally after its execution, the nested registry +/// can be disposed of as well. +ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent); + /// \brief A set of extensions used within a plan /// /// Each time an extension is used within a Substrait plan the extension @@ -140,7 +156,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { }; /// Construct an empty ExtensionSet to be populated during serialization. - explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry()); + explicit ExtensionSet(const ExtensionIdRegistry* = default_extension_id_registry()); ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet); /// Construct an ExtensionSet with explicit extension ids for efficient referencing @@ -160,7 +176,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { static Result Make( std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* = default_extension_id_registry()); + const ExtensionIdRegistry* = default_extension_id_registry()); // index in these vectors == value of _anchor/_reference fields /// TODO(ARROW-15583) this assumes that _anchor/_references won't be huge, which is not @@ -224,7 +240,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { std::size_t num_functions() const { return functions_.size(); } private: - ExtensionIdRegistry* registry_; + const ExtensionIdRegistry* registry_; /// The subset of extension registry URIs referenced by this extension set std::vector uris_; std::vector types_; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 5813dcde24c..8b3d53ae794 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -111,7 +111,7 @@ void SetElement(size_t i, const Element& element, std::vector* vector) { } // namespace Result GetExtensionSetFromPlan(const substrait::Plan& plan, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { std::vector uris; for (const auto& uri : plan.extension_uris()) { SetElement(uri.extension_uri_anchor(), uri.uri(), &uris); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 281cab0c0f3..dce23cdceba 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -49,7 +49,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( const substrait::Plan& plan, - ExtensionIdRegistry* registry = default_extension_id_registry()); + const ExtensionIdRegistry* registry = default_extension_id_registry()); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index bc2aa36856e..2ae3771f3fb 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -123,6 +123,10 @@ Result> SerializeJsonPlan(const std::string& substrait_j return engine::internal::SubstraitFromJSON("Plan", substrait_json); } +std::shared_ptr MakeExtensionIdRegistry() { + return nested_extension_id_registry(default_extension_id_registry()); +} + } // namespace substrait } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 860a459da2f..3b17bd880f2 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -37,6 +37,8 @@ ARROW_ENGINE_EXPORT Result> ExecuteSerialized ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); +ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); + } // namespace substrait } // namespace engine