diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc
index 7e1975d3b68..fe7c6fa8ad1 100644
--- a/cpp/src/arrow/compute/registry.cc
+++ b/cpp/src/arrow/compute/registry.cc
@@ -34,52 +34,62 @@ namespace compute {
 
 class FunctionRegistry::FunctionRegistryImpl {
  public:
-  Status AddFunction(std::shared_ptr function, bool allow_overwrite) {
-#ifndef NDEBUG
-    // This validates docstrings extensively, so don't waste time on it
-    // in release builds.
-    RETURN_NOT_OK(function->Validate());
-#endif
+  explicit FunctionRegistryImpl(FunctionRegistryImpl* parent = NULLPTR)
+      : parent_(parent) {}
+  ~FunctionRegistryImpl() {}
 
-    std::lock_guard mutation_guard(lock_);
+  Status CanAddFunction(std::shared_ptr function, bool allow_overwrite) {
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite));
+    }
+    return DoAddFunction(function, allow_overwrite, /*add=*/false);
+  }
 
-    const std::string& name = function->name();
-    auto it = name_to_function_.find(name);
-    if (it != name_to_function_.end() && !allow_overwrite) {
-      return Status::KeyError("Already have a function registered with name: ", name);
+  Status AddFunction(std::shared_ptr function, bool allow_overwrite) {
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite));
     }
-    name_to_function_[name] = std::move(function);
-    return Status::OK();
+    return DoAddFunction(function, allow_overwrite, /*add=*/true);
+  }
+
+  Status CanAddAlias(const std::string& target_name, const std::string& source_name) {
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddFunctionName(target_name,
+                                                /*allow_overwrite=*/false));
+    }
+    return DoAddAlias(target_name, source_name, /*add=*/false);
   }
 
   Status AddAlias(const std::string& target_name, const std::string& source_name) {
-    std::lock_guard mutation_guard(lock_);
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddFunctionName(target_name,
+                                                /*allow_overwrite=*/false));
+    }
+    return DoAddAlias(target_name, source_name, /*add=*/true);
+  }
 
-    auto it = name_to_function_.find(source_name);
-    if (it == name_to_function_.end()) {
-      return Status::KeyError("No function registered with name: ", source_name);
+  Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type,
+                                   bool allow_overwrite = false) {
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite));
     }
-    name_to_function_[target_name] = it->second;
-    return Status::OK();
+    return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/false);
   }
 
   Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
                                 bool allow_overwrite = false) {
-    std::lock_guard mutation_guard(lock_);
-
-    const std::string name = options_type->type_name();
-    auto it = name_to_options_type_.find(name);
-    if (it != name_to_options_type_.end() && !allow_overwrite) {
-      return Status::KeyError(
-          "Already have a function options type registered with name: ", name);
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite));
     }
-    name_to_options_type_[name] = options_type;
-    return Status::OK();
+    return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/true);
   }
 
   Result> GetFunction(const std::string& name) const {
     auto it = name_to_function_.find(name);
     if (it == name_to_function_.end()) {
+      if (parent_ != NULLPTR) {
+        return parent_->GetFunction(name);
+      }
       return Status::KeyError("No function registered with name: ", name);
     }
     return it->second;
@@ -87,6 +97,9 @@ class FunctionRegistry::FunctionRegistryImpl {
 
   std::vector GetFunctionNames() const {
     std::vector results;
+    if (parent_ != NULLPTR) {
+      results = parent_->GetFunctionNames();
+    }
     for (auto it : name_to_function_) {
       results.push_back(it.first);
     }
@@ -98,14 +111,103 @@ class FunctionRegistry::FunctionRegistryImpl {
       const std::string& name) const {
     auto it = name_to_options_type_.find(name);
     if (it == name_to_options_type_.end()) {
+      if (parent_ != NULLPTR) {
+        return parent_->GetFunctionOptionsType(name);
+      }
       return Status::KeyError("No function options type registered with name: ", name);
     }
     return it->second;
   }
 
-  int num_functions() const { return static_cast(name_to_function_.size()); }
+  int num_functions() const {
+    return (parent_ == NULLPTR ? 0 : parent_->num_functions()) +
+           static_cast(name_to_function_.size());
+  }
 
  private:
+  // must not acquire mutex
+  Status CanAddFunctionName(const std::string& name, bool allow_overwrite) {
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddFunctionName(name, allow_overwrite));
+    }
+    if (!allow_overwrite) {
+      auto it = name_to_function_.find(name);
+      if (it != name_to_function_.end()) {
+        return Status::KeyError("Already have a function registered with name: ", name);
+      }
+    }
+    return Status::OK();
+  }
+
+  // must not acquire mutex
+  Status CanAddOptionsTypeName(const std::string& name, bool allow_overwrite) {
+    if (parent_ != NULLPTR) {
+      RETURN_NOT_OK(parent_->CanAddOptionsTypeName(name, allow_overwrite));
+    }
+    if (!allow_overwrite) {
+      auto it = name_to_options_type_.find(name);
+      if (it != name_to_options_type_.end()) {
+        return Status::KeyError(
+            "Already have a function options type registered with name: ", name);
+      }
+    }
+    return Status::OK();
+  }
+
+  // Shared implementation of CanAddFunction/AddFunction: checks that the name is
+  // available and registers the function only when `add` is true.
+  Status DoAddFunction(std::shared_ptr function, bool allow_overwrite,
+                       bool add) {
+#ifndef NDEBUG
+    // This validates docstrings extensively, so don't waste time on it
+    // in release builds.
+    RETURN_NOT_OK(function->Validate());
+#endif
+
+    std::lock_guard mutation_guard(lock_);
+
+    const std::string& name = function->name();
+    RETURN_NOT_OK(CanAddFunctionName(name, allow_overwrite));
+    if (add) {
+      name_to_function_[name] = std::move(function);
+    }
+    return Status::OK();
+  }
+
+  // Shared implementation of CanAddAlias/AddAlias: the alias is stored only when
+  // `add` is true.
+  Status DoAddAlias(const std::string& target_name, const std::string& source_name,
+                    bool add) {
+    // source name must exist in this registry or the parent
+    // check outside mutex, in case GetFunction leads to mutex acquisition
+    ARROW_ASSIGN_OR_RAISE(auto func, GetFunction(source_name));
+
+    std::lock_guard mutation_guard(lock_);
+
+    // target name must be available in this registry and the parent
+    RETURN_NOT_OK(CanAddFunctionName(target_name, /*allow_overwrite=*/false));
+    if (add) {
+      name_to_function_[target_name] = func;
+    }
+    return Status::OK();
+  }
+
+  // Shared implementation of CanAddFunctionOptionsType/AddFunctionOptionsType.
+  // `allow_overwrite` is forwarded to the name check; the options type is stored
+  // only when `add` is true.
+  Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type,
+                                  bool allow_overwrite, bool add) {
+    std::lock_guard mutation_guard(lock_);
+
+    const std::string name = options_type->type_name();
+    RETURN_NOT_OK(CanAddOptionsTypeName(name, allow_overwrite));
+    if (add) {
+      name_to_options_type_[name] = options_type;
+    }
+    return Status::OK();
+  }
+
+  FunctionRegistryImpl* parent_;
   std::mutex lock_;
   std::unordered_map> name_to_function_;
   std::unordered_map name_to_options_type_;
@@ -115,20 +217,42 @@ std::unique_ptr FunctionRegistry::Make() {
   return std::unique_ptr(new FunctionRegistry());
 }
 
-FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); }
+std::unique_ptr FunctionRegistry::Make(FunctionRegistry* parent) {
+  return std::unique_ptr(new FunctionRegistry(
+      new FunctionRegistry::FunctionRegistryImpl(parent->impl_.get())));
+}
+
+FunctionRegistry::FunctionRegistry() : FunctionRegistry(new FunctionRegistryImpl()) {}
+
+FunctionRegistry::FunctionRegistry(FunctionRegistryImpl* impl) { impl_.reset(impl); }
 
 FunctionRegistry::~FunctionRegistry() {}
 
+Status FunctionRegistry::CanAddFunction(std::shared_ptr function,
+                                        bool allow_overwrite) {
+  return impl_->CanAddFunction(std::move(function), allow_overwrite);
+}
+
 Status FunctionRegistry::AddFunction(std::shared_ptr function,
                                      bool allow_overwrite) {
   return impl_->AddFunction(std::move(function), allow_overwrite);
 }
 
+Status FunctionRegistry::CanAddAlias(const std::string& target_name,
+                                     const std::string& source_name) {
+  return impl_->CanAddAlias(target_name, source_name);
+}
+
 Status FunctionRegistry::AddAlias(const std::string& target_name,
                                   const std::string& source_name) {
   return impl_->AddAlias(target_name, source_name);
 }
 
+Status FunctionRegistry::CanAddFunctionOptionsType(
+    const FunctionOptionsType* options_type, bool allow_overwrite) {
+  return impl_->CanAddFunctionOptionsType(options_type, allow_overwrite);
+}
+
 Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type,
                                                 bool allow_overwrite) {
   return impl_->AddFunctionOptionsType(options_type, allow_overwrite);
diff --git a/cpp/src/arrow/compute/registry.h b/cpp/src/arrow/compute/registry.h
index e83036db6ac..16a76255946 100644
--- a/cpp/src/arrow/compute/registry.h
+++ b/cpp/src/arrow/compute/registry.h
@@ -47,35 +47,64 @@ class ARROW_EXPORT FunctionRegistry {
  public:
   ~FunctionRegistry();
 
-  /// \brief Construct a new registry. Most users only need to use the global
-  /// registry
+  /// \brief Construct a new registry.
+  ///
+  /// Most users only need to use the global registry.
   static std::unique_ptr Make();
 
-  /// \brief Add a new function to the registry. Returns Status::KeyError if a
-  /// function with the same name is already registered
+  /// \brief Construct a new nested registry with the given parent.
+  ///
+  /// Most users only need to use the global registry. The returned registry never changes
+  /// its parent, even when an operation allows overwriting.
+  static std::unique_ptr Make(FunctionRegistry* parent);
+
+  /// \brief Check whether a new function can be added to the registry.
+  ///
+  /// \returns Status::KeyError if a function with the same name is already registered.
+  Status CanAddFunction(std::shared_ptr function, bool allow_overwrite = false);
+
+  /// \brief Add a new function to the registry.
+  ///
+  /// \returns Status::KeyError if a function with the same name is already registered.
   Status AddFunction(std::shared_ptr function, bool allow_overwrite = false);
 
-  /// \brief Add aliases for the given function name. Returns Status::KeyError if the
-  /// function with the given name is not registered
+  /// \brief Check whether an alias can be added for the given function name.
+  ///
+  /// \returns Status::KeyError if the function with the given name is not registered.
+  Status CanAddAlias(const std::string& target_name, const std::string& source_name);
+
+  /// \brief Add alias for the given function name.
+  ///
+  /// \returns Status::KeyError if the function with the given name is not registered.
   Status AddAlias(const std::string& target_name, const std::string& source_name);
 
-  /// \brief Add a new function options type to the registry. Returns Status::KeyError if
-  /// a function options type with the same name is already registered
+  /// \brief Check whether a new function options type can be added to the registry.
+  ///
+  /// \returns Status::KeyError if a function options type with the same name is already
+  /// registered.
+  Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type,
+                                   bool allow_overwrite = false);
+
+  /// \brief Add a new function options type to the registry.
+  ///
+  /// \returns Status::KeyError if a function options type with the same name is already
+  /// registered.
   Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
                                 bool allow_overwrite = false);
 
-  /// \brief Retrieve a function by name from the registry
+  /// \brief Retrieve a function by name from the registry.
   Result> GetFunction(const std::string& name) const;
 
-  /// \brief Return vector of all entry names in the registry. Helpful for
-  /// displaying a manifest of available functions
+  /// \brief Return vector of all entry names in the registry.
+  ///
+  /// Helpful for displaying a manifest of available functions.
   std::vector GetFunctionNames() const;
 
-  /// \brief Retrieve a function options type by name from the registry
+  /// \brief Retrieve a function options type by name from the registry.
   Result GetFunctionOptionsType(
       const std::string& name) const;
 
-  /// \brief The number of currently registered functions
+  /// \brief The number of currently registered functions.
   int num_functions() const;
 
  private:
@@ -84,9 +113,13 @@ class ARROW_EXPORT FunctionRegistry {
   // Use PIMPL pattern to not have std::unordered_map here
   class FunctionRegistryImpl;
   std::unique_ptr impl_;
+
+  explicit FunctionRegistry(FunctionRegistryImpl* impl);
+
+  class NestedFunctionRegistryImpl;
 };
 
-/// \brief Return the process-global function registry
+/// \brief Return the process-global function registry.
 ARROW_EXPORT FunctionRegistry* GetFunctionRegistry();
 
 }  // namespace compute
diff --git a/cpp/src/arrow/compute/registry_test.cc b/cpp/src/arrow/compute/registry_test.cc
index faf47a46f68..937515af4ac 100644
--- a/cpp/src/arrow/compute/registry_test.cc
+++ b/cpp/src/arrow/compute/registry_test.cc
@@ -27,37 +27,44 @@
 #include "arrow/status.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
 
 namespace arrow {
 namespace compute {
 
-class TestRegistry : public ::testing::Test {
- public:
-  void SetUp() { registry_ = FunctionRegistry::Make(); }
+using MakeFunctionRegistry = std::function()>;
+using GetNumFunctions = std::function;
+using GetFunctionNames = std::function()>;
+using TestRegistryParams =
+    std::tuple;
 
- protected:
-  std::unique_ptr registry_;
-};
+struct TestRegistry : public ::testing::TestWithParam {};
 
-TEST_F(TestRegistry, CreateBuiltInRegistry) {
+TEST(TestRegistry, CreateBuiltInRegistry) {
   // This does DCHECK_OK internally for now so this will fail in debug builds
   // if there is a problem initializing the global function registry
   FunctionRegistry* registry = GetFunctionRegistry();
   ARROW_UNUSED(registry);
 }
 
-TEST_F(TestRegistry, Basics) {
-  ASSERT_EQ(0, registry_->num_functions());
+TEST_P(TestRegistry, Basics) {
+  auto registry_factory = std::get<0>(GetParam());
+  auto registry_ = registry_factory();
+  auto get_num_funcs = std::get<1>(GetParam());
+  int n_funcs = get_num_funcs();
+  auto get_func_names = std::get<2>(GetParam());
+  std::vector func_names = get_func_names();
+  ASSERT_EQ(n_funcs, registry_->num_functions());
 
   std::shared_ptr func = std::make_shared(
       "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty());
   ASSERT_OK(registry_->AddFunction(func));
-  ASSERT_EQ(1, registry_->num_functions());
+  ASSERT_EQ(n_funcs + 1, registry_->num_functions());
 
   func = std::make_shared("f0", Arity::Binary(),
                           /*doc=*/FunctionDoc::Empty());
   ASSERT_OK(registry_->AddFunction(func));
-  ASSERT_EQ(2, registry_->num_functions());
+  ASSERT_EQ(n_funcs + 2, registry_->num_functions());
 
   ASSERT_OK_AND_ASSIGN(std::shared_ptr f1, registry_->GetFunction("f1"));
   ASSERT_EQ("f1", f1->name());
@@ -75,7 +82,11 @@ TEST_F(TestRegistry, Basics) {
   ASSERT_OK_AND_ASSIGN(f1, registry_->GetFunction("f1"));
   ASSERT_EQ(Function::SCALAR_AGGREGATE, f1->kind());
 
-  std::vector expected_names = {"f0", "f1"};
+  std::vector expected_names(func_names);
+  for (auto name : {"f0", "f1"}) {
+    expected_names.push_back(name);
+  }
+  std::sort(expected_names.begin(), expected_names.end());
   ASSERT_EQ(expected_names, registry_->GetFunctionNames());
 
   // Aliases
@@ -85,5 +96,139 @@ TEST_F(TestRegistry, Basics) {
   ASSERT_EQ(func, f2);
 }
 
+INSTANTIATE_TEST_SUITE_P(
+    TestRegistry, TestRegistry,
+    testing::Values(
+        std::make_tuple(
+            static_cast([]() { return FunctionRegistry::Make(); }),
+            []() { return 0; }, []() { return std::vector{}; }, "default"),
+        std::make_tuple(
+            static_cast([]() {
+              return FunctionRegistry::Make(GetFunctionRegistry());
+            }),
+            []() { return GetFunctionRegistry()->num_functions(); },
+            []() { return GetFunctionRegistry()->GetFunctionNames(); }, "nested")));
+
+TEST(TestRegistry, RegisterTempFunctions) {
+  auto default_registry = GetFunctionRegistry();
+  constexpr int rounds = 3;
+  for (int i = 0; i < rounds; i++) {
+    auto registry = FunctionRegistry::Make(default_registry);
+    for (std::string func_name : {"f1", "f2"}) {
+      std::shared_ptr func = std::make_shared(
+          func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty());
+      ASSERT_OK(registry->CanAddFunction(func));
+      ASSERT_OK(registry->AddFunction(func));
+      ASSERT_RAISES(KeyError, registry->CanAddFunction(func));
+      ASSERT_RAISES(KeyError, registry->AddFunction(func));
+      ASSERT_OK(default_registry->CanAddFunction(func));
+    }
+  }
+}
+
+TEST(TestRegistry, RegisterTempAliases) {
+  auto default_registry = GetFunctionRegistry();
+  std::vector func_names = default_registry->GetFunctionNames();
+  constexpr int rounds = 3;
+  for (int i = 0; i < rounds; i++) {
+    auto registry = FunctionRegistry::Make(default_registry);
+    for (std::string func_name : func_names) {
+      std::string alias_name = "alias_of_" + func_name;
+      std::shared_ptr func = std::make_shared(
+          func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty());
+      ASSERT_RAISES(KeyError, registry->GetFunction(alias_name));
+      ASSERT_OK(registry->CanAddAlias(alias_name, func_name));
+      ASSERT_OK(registry->AddAlias(alias_name, func_name));
+      ASSERT_OK(registry->GetFunction(alias_name));
+      ASSERT_OK(default_registry->GetFunction(func_name));
+      ASSERT_RAISES(KeyError, default_registry->GetFunction(alias_name));
+    }
+  }
+}
+
+template
+class ExampleOptions : public FunctionOptions {
+ public:
+  explicit ExampleOptions(std::shared_ptr value);
+  std::shared_ptr value;
+};
+
+template
+class ExampleOptionsType : public FunctionOptionsType {
+ public:
+  static const FunctionOptionsType* GetInstance() {
+    static std::unique_ptr instance(
+        new ExampleOptionsType());
+    return instance.get();
+  }
+  const char* type_name() const override {
+    static std::string name = std::string("example") + std::to_string(kExampleSeqNum);
+    return name.c_str();
+  }
+  std::string Stringify(const FunctionOptions& options) const override {
+    return type_name();
+  }
+  bool Compare(const FunctionOptions& options,
+               const FunctionOptions& other) const override {
+    return true;
+  }
+  std::unique_ptr Copy(const FunctionOptions& options) const override {
+    const auto& opts = static_cast&>(options);
+    return arrow::internal::make_unique>(opts.value);
+  }
+};
+template
+ExampleOptions::ExampleOptions(std::shared_ptr value)
+    : FunctionOptions(ExampleOptionsType::GetInstance()),
+      value(std::move(value)) {}
+
+TEST(TestRegistry, RegisterTempFunctionOptionsType) {
+  auto default_registry = GetFunctionRegistry();
+  std::vector options_types = {
+      ExampleOptionsType<1>::GetInstance(),
+      ExampleOptionsType<2>::GetInstance(),
+  };
+  constexpr int rounds = 3;
+  for (int i = 0; i < rounds; i++) {
+    auto registry = FunctionRegistry::Make(default_registry);
+    for (auto options_type : options_types) {
+      ASSERT_OK(registry->CanAddFunctionOptionsType(options_type));
+      ASSERT_OK(registry->AddFunctionOptionsType(options_type));
+      ASSERT_RAISES(KeyError, registry->CanAddFunctionOptionsType(options_type));
+      ASSERT_RAISES(KeyError, registry->AddFunctionOptionsType(options_type));
+      ASSERT_OK(default_registry->CanAddFunctionOptionsType(options_type));
+    }
+  }
+}
+
+TEST(TestRegistry, RegisterNestedFunctions) {
+  auto default_registry = GetFunctionRegistry();
+  std::shared_ptr func1 = std::make_shared(
+      "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty());
+  std::shared_ptr func2 = std::make_shared(
+      "f2", Arity::Unary(), /*doc=*/FunctionDoc::Empty());
+  constexpr int rounds = 3;
+  for (int i = 0; i < rounds; i++) {
+    auto registry1 = FunctionRegistry::Make(default_registry);
+
+    ASSERT_OK(registry1->CanAddFunction(func1));
+    ASSERT_OK(registry1->AddFunction(func1));
+
+    for (int j = 0; j < rounds; j++) {
+      auto registry2 = FunctionRegistry::Make(registry1.get());
+
+      ASSERT_OK(registry2->CanAddFunction(func2));
+      ASSERT_OK(registry2->AddFunction(func2));
+      ASSERT_RAISES(KeyError, registry2->CanAddFunction(func2));
+      ASSERT_RAISES(KeyError, registry2->AddFunction(func2));
+      ASSERT_OK(default_registry->CanAddFunction(func2));
+    }
+
+    ASSERT_RAISES(KeyError, registry1->CanAddFunction(func1));
+    ASSERT_RAISES(KeyError, registry1->AddFunction(func1));
+    ASSERT_OK(default_registry->CanAddFunction(func1));
+  }
+}
+
 }  // namespace compute
 }  // namespace arrow