From 022dc3878b57fc7f134cdba21e017d289c37f320 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 27 May 2022 05:06:32 -0400 Subject: [PATCH 1/8] ARROW-16677: [C++] Support nesting of function registries --- cpp/src/arrow/compute/registry.cc | 191 +++++++++++++++++++++++-- cpp/src/arrow/compute/registry.h | 31 +++- cpp/src/arrow/compute/registry_test.cc | 165 +++++++++++++++++++-- 3 files changed, 358 insertions(+), 29 deletions(-) diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 8ab83a72e5e..772146d8ec7 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -33,8 +33,18 @@ namespace arrow { namespace compute { class FunctionRegistry::FunctionRegistryImpl { - public: - Status AddFunction(std::shared_ptr function, bool allow_overwrite) { + private: + using FuncAdd = std::function)>; + + const FuncAdd kFuncAddNoOp = + [](const std::string& name, std::shared_ptr func) {}; + const FuncAdd kFuncAddDo = + [this](const std::string& name, std::shared_ptr func) { + name_to_function_[name] = func; + }; + + Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, + FuncAdd add) { #ifndef NDEBUG // This validates docstrings extensively, so don't waste time on it // in release builds. 
@@ -48,23 +58,56 @@ class FunctionRegistry::FunctionRegistryImpl { if (it != name_to_function_.end() && !allow_overwrite) { return Status::KeyError("Already have a function registered with name: ", name); } - name_to_function_[name] = std::move(function); + add(name, std::move(function)); return Status::OK(); } - Status AddAlias(const std::string& target_name, const std::string& source_name) { + public: + virtual Status CanAddFunction(std::shared_ptr function, + bool allow_overwrite) { + return DoAddFunction(function, allow_overwrite, kFuncAddNoOp); + } + + virtual Status AddFunction(std::shared_ptr function, bool allow_overwrite) { + return DoAddFunction(function, allow_overwrite, kFuncAddDo); + } + + private: + Status DoAddAlias(const std::string& target_name, const std::string& source_name, + FuncAdd add) { std::lock_guard mutation_guard(lock_); - auto it = name_to_function_.find(source_name); - if (it == name_to_function_.end()) { + auto func_res = GetFunction(source_name); // must not acquire the mutex + if (!func_res.ok()) { return Status::KeyError("No function registered with name: ", source_name); } - name_to_function_[target_name] = it->second; + add(target_name, func_res.ValueOrDie()); return Status::OK(); } - Status AddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) { + public: + virtual Status CanAddAlias(const std::string& target_name, + const std::string& source_name) { + return DoAddAlias(target_name, source_name, kFuncAddNoOp); + } + + virtual Status AddAlias(const std::string& target_name, + const std::string& source_name) { + return DoAddAlias(target_name, source_name, kFuncAddDo); + } + + private: + using FuncOptTypeAdd = std::function; + + const FuncOptTypeAdd kFuncOptTypeAddNoOp = + [](const FunctionOptionsType* options_type) {}; + const FuncOptTypeAdd kFuncOptTypeAddDo = + [this](const FunctionOptionsType* options_type) { + name_to_options_type_[options_type->type_name()] = options_type; + }; + + 
Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite, FuncOptTypeAdd add) { std::lock_guard mutation_guard(lock_); const std::string name = options_type->type_name(); @@ -73,11 +116,22 @@ class FunctionRegistry::FunctionRegistryImpl { return Status::KeyError( "Already have a function options type registered with name: ", name); } - name_to_options_type_[name] = options_type; + add(options_type); return Status::OK(); } - Result> GetFunction(const std::string& name) const { + public: + virtual Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddNoOp); + } + + virtual Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddDo); + } + + virtual Result> GetFunction(const std::string& name) const { auto it = name_to_function_.find(name); if (it == name_to_function_.end()) { return Status::KeyError("No function registered with name: ", name); @@ -85,7 +139,7 @@ class FunctionRegistry::FunctionRegistryImpl { return it->second; } - std::vector GetFunctionNames() const { + virtual std::vector GetFunctionNames() const { std::vector results; for (auto it : name_to_function_) { results.push_back(it.first); @@ -94,7 +148,7 @@ class FunctionRegistry::FunctionRegistryImpl { return results; } - Result GetFunctionOptionsType( + virtual Result GetFunctionOptionsType( const std::string& name) const { auto it = name_to_options_type_.find(name); if (it == name_to_options_type_.end()) { @@ -103,7 +157,7 @@ class FunctionRegistry::FunctionRegistryImpl { return it->second; } - int num_functions() const { return static_cast(name_to_function_.size()); } + virtual int num_functions() const { return static_cast(name_to_function_.size()); } private: std::mutex lock_; @@ -111,24 +165,131 @@ 
class FunctionRegistry::FunctionRegistryImpl { std::unordered_map name_to_options_type_; }; +class FunctionRegistry::NestedFunctionRegistryImpl + : public FunctionRegistry::FunctionRegistryImpl { + public: + explicit NestedFunctionRegistryImpl(FunctionRegistry::FunctionRegistryImpl* parent) + : parent_(parent) {} + + Status CanAddFunction(std::shared_ptr function, + bool allow_overwrite) override { + return parent_->CanAddFunction(function, allow_overwrite) & + FunctionRegistry::FunctionRegistryImpl::CanAddFunction(function, allow_overwrite); + } + + Status AddFunction(std::shared_ptr function, bool allow_overwrite) override { + return parent_->CanAddFunction(function, allow_overwrite) & + FunctionRegistry::FunctionRegistryImpl::AddFunction(function, allow_overwrite); + } + + Status CanAddAlias(const std::string& target_name, + const std::string& source_name) override { + Status st = FunctionRegistry::FunctionRegistryImpl::CanAddAlias(target_name, + source_name); + return st.ok() ? st : parent_->CanAddAlias(target_name, source_name); + } + + Status AddAlias(const std::string& target_name, + const std::string& source_name) override { + Status st = FunctionRegistry::FunctionRegistryImpl::AddAlias(target_name, + source_name); + return st.ok() ? 
st : parent_->AddAlias(target_name, source_name); + } + + Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) override { + return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & + FunctionRegistry::FunctionRegistryImpl::CanAddFunctionOptionsType( + options_type, allow_overwrite); + } + + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) override { + return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & + FunctionRegistry::FunctionRegistryImpl::AddFunctionOptionsType( + options_type, allow_overwrite); + } + + Result> GetFunction(const std::string& name) const override { + auto func_res = FunctionRegistry::FunctionRegistryImpl::GetFunction(name); + if (func_res.ok()) { + return func_res; + } + return parent_->GetFunction(name); + } + + std::vector GetFunctionNames() const override { + auto names = parent_->GetFunctionNames(); + auto more_names = FunctionRegistry::FunctionRegistryImpl::GetFunctionNames(); + names.insert(names.end(), std::make_move_iterator(more_names.begin()), + std::make_move_iterator(more_names.end())); + return names; + } + + Result GetFunctionOptionsType( + const std::string& name) const override { + auto options_type_res = + FunctionRegistry::FunctionRegistryImpl::GetFunctionOptionsType(name); + if (options_type_res.ok()) { + return options_type_res; + } + return parent_->GetFunctionOptionsType(name); + } + + int num_functions() const override { + return parent_->num_functions() + + FunctionRegistry::FunctionRegistryImpl::num_functions(); + } + + private: + FunctionRegistry::FunctionRegistryImpl* parent_; +}; + std::unique_ptr FunctionRegistry::Make() { return std::unique_ptr(new FunctionRegistry()); } -FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); } +std::unique_ptr FunctionRegistry::Make(FunctionRegistry* parent) { + return std::unique_ptr(new FunctionRegistry( + 
new FunctionRegistry::NestedFunctionRegistryImpl(&*parent->impl_))); +} + +std::unique_ptr FunctionRegistry::Make( + std::unique_ptr parent) { + return FunctionRegistry::Make(&*parent); +} + +FunctionRegistry::FunctionRegistry() : FunctionRegistry(new FunctionRegistryImpl()) {} + +FunctionRegistry::FunctionRegistry(FunctionRegistryImpl* impl) { impl_.reset(impl); } FunctionRegistry::~FunctionRegistry() {} +Status FunctionRegistry::CanAddFunction(std::shared_ptr function, + bool allow_overwrite) { + return impl_->CanAddFunction(std::move(function), allow_overwrite); +} + Status FunctionRegistry::AddFunction(std::shared_ptr function, bool allow_overwrite) { return impl_->AddFunction(std::move(function), allow_overwrite); } +Status FunctionRegistry::CanAddAlias(const std::string& target_name, + const std::string& source_name) { + return impl_->CanAddAlias(target_name, source_name); +} + Status FunctionRegistry::AddAlias(const std::string& target_name, const std::string& source_name) { return impl_->AddAlias(target_name, source_name); } +Status FunctionRegistry::CanAddFunctionOptionsType( + const FunctionOptionsType* options_type, bool allow_overwrite) { + return impl_->CanAddFunctionOptionsType(options_type, allow_overwrite); +} + Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite) { return impl_->AddFunctionOptionsType(options_type, allow_overwrite); diff --git a/cpp/src/arrow/compute/registry.h b/cpp/src/arrow/compute/registry.h index e83036db6ac..de074e10d92 100644 --- a/cpp/src/arrow/compute/registry.h +++ b/cpp/src/arrow/compute/registry.h @@ -45,20 +45,42 @@ class FunctionOptionsType; /// lower-level function execution. class ARROW_EXPORT FunctionRegistry { public: - ~FunctionRegistry(); + virtual ~FunctionRegistry(); /// \brief Construct a new registry. Most users only need to use the global /// registry static std::unique_ptr Make(); + /// \brief Construct a new nested registry with the given parent. 
Most users only need + /// to use the global registry + static std::unique_ptr Make(FunctionRegistry* parent); + + /// \brief Construct a new nested registry with the given parent. Most users only need + /// to use the global registry + static std::unique_ptr Make(std::unique_ptr parent); + + /// \brief Checks whether a new function can be added to the registry. Returns + /// Status::KeyError if a function with the same name is already registered + Status CanAddFunction(std::shared_ptr function, bool allow_overwrite = false); + /// \brief Add a new function to the registry. Returns Status::KeyError if a /// function with the same name is already registered Status AddFunction(std::shared_ptr function, bool allow_overwrite = false); - /// \brief Add aliases for the given function name. Returns Status::KeyError if the + /// \brief Checks whether an alias can be added for the given function name. Returns + /// Status::KeyError if the function with the given name is not registered + Status CanAddAlias(const std::string& target_name, const std::string& source_name); + + /// \brief Add alias for the given function name. Returns Status::KeyError if the /// function with the given name is not registered Status AddAlias(const std::string& target_name, const std::string& source_name); + /// \brief Checks whether a new function options type can be added to the registry. + /// Returns Status::KeyError if a function options type with the same name is already + /// registered + Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false); + /// \brief Add a new function options type to the registry. 
Returns Status::KeyError if /// a function options type with the same name is already registered Status AddFunctionOptionsType(const FunctionOptionsType* options_type, @@ -84,6 +106,11 @@ class ARROW_EXPORT FunctionRegistry { // Use PIMPL pattern to not have std::unordered_map here class FunctionRegistryImpl; std::unique_ptr impl_; + + explicit FunctionRegistry(FunctionRegistryImpl* impl); + + class NestedFunctionRegistryImpl; + friend class NestedFunctionRegistryImpl; }; /// \brief Return the process-global function registry diff --git a/cpp/src/arrow/compute/registry_test.cc b/cpp/src/arrow/compute/registry_test.cc index faf47a46f68..22a62ab73ad 100644 --- a/cpp/src/arrow/compute/registry_test.cc +++ b/cpp/src/arrow/compute/registry_test.cc @@ -27,37 +27,44 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/macros.h" +#include "arrow/util/make_unique.h" namespace arrow { namespace compute { -class TestRegistry : public ::testing::Test { - public: - void SetUp() { registry_ = FunctionRegistry::Make(); } +using MakeFunctionRegistry = std::function()>; +using GetNumFunctions = std::function; +using GetFunctionNames = std::function()>; +using TestRegistryParams = + std::tuple; - protected: - std::unique_ptr registry_; -}; +struct TestRegistry : public ::testing::TestWithParam {}; -TEST_F(TestRegistry, CreateBuiltInRegistry) { +TEST(TestRegistry, CreateBuiltInRegistry) { // This does DCHECK_OK internally for now so this will fail in debug builds // if there is a problem initializing the global function registry FunctionRegistry* registry = GetFunctionRegistry(); ARROW_UNUSED(registry); } -TEST_F(TestRegistry, Basics) { - ASSERT_EQ(0, registry_->num_functions()); +TEST_P(TestRegistry, Basics) { + auto registry_factory = std::get<0>(GetParam()); + auto registry_ = registry_factory(); + auto get_num_funcs = std::get<1>(GetParam()); + int n_funcs = get_num_funcs(); + auto get_func_names = std::get<2>(GetParam()); + std::vector 
func_names = get_func_names(); + ASSERT_EQ(n_funcs + 0, registry_->num_functions()); std::shared_ptr func = std::make_shared( "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); ASSERT_OK(registry_->AddFunction(func)); - ASSERT_EQ(1, registry_->num_functions()); + ASSERT_EQ(n_funcs + 1, registry_->num_functions()); func = std::make_shared("f0", Arity::Binary(), /*doc=*/FunctionDoc::Empty()); ASSERT_OK(registry_->AddFunction(func)); - ASSERT_EQ(2, registry_->num_functions()); + ASSERT_EQ(n_funcs + 2, registry_->num_functions()); ASSERT_OK_AND_ASSIGN(std::shared_ptr f1, registry_->GetFunction("f1")); ASSERT_EQ("f1", f1->name()); @@ -75,7 +82,10 @@ TEST_F(TestRegistry, Basics) { ASSERT_OK_AND_ASSIGN(f1, registry_->GetFunction("f1")); ASSERT_EQ(Function::SCALAR_AGGREGATE, f1->kind()); - std::vector expected_names = {"f0", "f1"}; + std::vector expected_names(func_names); + for (auto name : {"f0", "f1"}) { + expected_names.push_back(name); + } ASSERT_EQ(expected_names, registry_->GetFunctionNames()); // Aliases @@ -85,5 +95,136 @@ TEST_F(TestRegistry, Basics) { ASSERT_EQ(func, f2); } +INSTANTIATE_TEST_SUITE_P( + TestRegistry, TestRegistry, + testing::Values( + std::make_tuple(static_cast( + []() { return FunctionRegistry::Make(); }), []() { return 0; }, + []() { return std::vector{}; }, "default"), + std::make_tuple(static_cast( + []() { return FunctionRegistry::Make(GetFunctionRegistry()); }), + []() { return GetFunctionRegistry()->num_functions(); }, + []() { return GetFunctionRegistry()->GetFunctionNames(); }, "nested"))); + +TEST(TestRegistry, RegisterTempFunctions) { + auto default_registry = GetFunctionRegistry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = FunctionRegistry::Make(default_registry); + for (std::string func_name : {"f1", "f2"}) { + std::shared_ptr func = std::make_shared( + func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + ASSERT_OK(registry->CanAddFunction(func)); + 
ASSERT_OK(registry->AddFunction(func)); + ASSERT_RAISES(KeyError, registry->CanAddFunction(func)); + ASSERT_RAISES(KeyError, registry->AddFunction(func)); + ASSERT_OK(default_registry->CanAddFunction(func)); + } + } +} + +TEST(TestRegistry, RegisterTempAliases) { + auto default_registry = GetFunctionRegistry(); + std::vector func_names = default_registry->GetFunctionNames(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = FunctionRegistry::Make(default_registry); + for (std::string func_name : func_names) { + std::string alias_name = "alias_of_" + func_name; + std::shared_ptr func = std::make_shared( + func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + ASSERT_RAISES(KeyError, registry->GetFunction(alias_name)); + ASSERT_OK(registry->CanAddAlias(alias_name, func_name)); + ASSERT_OK(registry->AddAlias(alias_name, func_name)); + ASSERT_OK(registry->GetFunction(alias_name)); + ASSERT_OK(default_registry->GetFunction(func_name)); + ASSERT_RAISES(KeyError, default_registry->GetFunction(alias_name)); + } + } +} + +template +class ExampleOptions : public FunctionOptions { + public: + explicit ExampleOptions(std::shared_ptr value); + std::shared_ptr value; +}; + +template +class ExampleOptionsType : public FunctionOptionsType { + public: + static const FunctionOptionsType* GetInstance() { + static std::unique_ptr instance(new ExampleOptionsType()); + return instance.get(); + } + const char* type_name() const override { + static std::string name = std::string("example") + std::to_string(N); + return name.c_str(); + } + std::string Stringify(const FunctionOptions& options) const override { + return type_name(); + } + bool Compare(const FunctionOptions& options, + const FunctionOptions& other) const override { + return true; + } + std::unique_ptr Copy(const FunctionOptions& options) const override { + const auto& opts = static_cast&>(options); + return arrow::internal::make_unique>(opts.value); + } +}; +template 
+ExampleOptions::ExampleOptions(std::shared_ptr value) + : FunctionOptions(ExampleOptionsType::GetInstance()), value(std::move(value)) {} + +TEST(TestRegistry, RegisterTempFunctionOptionsType) { + auto default_registry = GetFunctionRegistry(); + std::vector options_types = { + ExampleOptionsType<1>::GetInstance(), + ExampleOptionsType<2>::GetInstance(), + }; + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = FunctionRegistry::Make(default_registry); + for (auto options_type : options_types) { + ASSERT_OK(registry->CanAddFunctionOptionsType(options_type)); + ASSERT_OK(registry->AddFunctionOptionsType(options_type)); + ASSERT_RAISES(KeyError, registry->CanAddFunctionOptionsType(options_type)); + ASSERT_RAISES(KeyError, registry->AddFunctionOptionsType(options_type)); + ASSERT_OK(default_registry->CanAddFunctionOptionsType(options_type)); + } + } +} + +TEST(TestRegistry, RegisterNestedFunctions) { + auto default_registry = GetFunctionRegistry(); + std::shared_ptr func1 = std::make_shared( + "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + std::shared_ptr func2 = std::make_shared( + "f2", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = FunctionRegistry::Make(default_registry); + + ASSERT_OK(registry1->CanAddFunction(func1)); + ASSERT_OK(registry1->AddFunction(func1)); + + for (int j = 0; j < rounds; j++) { + auto registry2 = FunctionRegistry::Make(registry1.get()); + + ASSERT_OK(registry2->CanAddFunction(func2)); + ASSERT_OK(registry2->AddFunction(func2)); + ASSERT_RAISES(KeyError, registry2->CanAddFunction(func2)); + ASSERT_RAISES(KeyError, registry2->AddFunction(func2)); + ASSERT_OK(default_registry->CanAddFunction(func2)); + } + + ASSERT_RAISES(KeyError, registry1->CanAddFunction(func1)); + ASSERT_RAISES(KeyError, registry1->AddFunction(func1)); + ASSERT_OK(default_registry->CanAddFunction(func1)); + } +} + + } // namespace compute } // namespace 
arrow From fb7e9bd60b86480d3f056bbc5bd3383c3a4d83c4 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 27 May 2022 06:10:29 -0400 Subject: [PATCH 2/8] lint --- cpp/src/arrow/compute/registry.cc | 43 +++++++++++++------------- cpp/src/arrow/compute/registry_test.cc | 20 ++++++------ 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 772146d8ec7..5f0a43468c1 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -36,12 +36,12 @@ class FunctionRegistry::FunctionRegistryImpl { private: using FuncAdd = std::function)>; - const FuncAdd kFuncAddNoOp = - [](const std::string& name, std::shared_ptr func) {}; - const FuncAdd kFuncAddDo = - [this](const std::string& name, std::shared_ptr func) { - name_to_function_[name] = func; - }; + const FuncAdd kFuncAddNoOp = [](const std::string& name, + std::shared_ptr func) {}; + const FuncAdd kFuncAddDo = [this](const std::string& name, + std::shared_ptr func) { + name_to_function_[name] = func; + }; Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, FuncAdd add) { @@ -77,7 +77,7 @@ class FunctionRegistry::FunctionRegistryImpl { FuncAdd add) { std::lock_guard mutation_guard(lock_); - auto func_res = GetFunction(source_name); // must not acquire the mutex + auto func_res = GetFunction(source_name); // must not acquire the mutex if (!func_res.ok()) { return Status::KeyError("No function registered with name: ", source_name); } @@ -99,11 +99,11 @@ class FunctionRegistry::FunctionRegistryImpl { private: using FuncOptTypeAdd = std::function; - const FuncOptTypeAdd kFuncOptTypeAddNoOp = - [](const FunctionOptionsType* options_type) {}; + const FuncOptTypeAdd kFuncOptTypeAddNoOp = [](const FunctionOptionsType* options_type) { + }; const FuncOptTypeAdd kFuncOptTypeAddDo = [this](const FunctionOptionsType* options_type) { - name_to_options_type_[options_type->type_name()] = options_type; + 
name_to_options_type_[options_type->type_name()] = options_type; }; Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type, @@ -174,40 +174,41 @@ class FunctionRegistry::NestedFunctionRegistryImpl Status CanAddFunction(std::shared_ptr function, bool allow_overwrite) override { return parent_->CanAddFunction(function, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::CanAddFunction(function, allow_overwrite); + FunctionRegistry::FunctionRegistryImpl::CanAddFunction(function, + allow_overwrite); } Status AddFunction(std::shared_ptr function, bool allow_overwrite) override { return parent_->CanAddFunction(function, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::AddFunction(function, allow_overwrite); + FunctionRegistry::FunctionRegistryImpl::AddFunction(function, allow_overwrite); } Status CanAddAlias(const std::string& target_name, const std::string& source_name) override { - Status st = FunctionRegistry::FunctionRegistryImpl::CanAddAlias(target_name, - source_name); + Status st = + FunctionRegistry::FunctionRegistryImpl::CanAddAlias(target_name, source_name); return st.ok() ? st : parent_->CanAddAlias(target_name, source_name); } Status AddAlias(const std::string& target_name, const std::string& source_name) override { - Status st = FunctionRegistry::FunctionRegistryImpl::AddAlias(target_name, - source_name); + Status st = + FunctionRegistry::FunctionRegistryImpl::AddAlias(target_name, source_name); return st.ok() ? 
st : parent_->AddAlias(target_name, source_name); } Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false) override { return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::CanAddFunctionOptionsType( - options_type, allow_overwrite); + FunctionRegistry::FunctionRegistryImpl::CanAddFunctionOptionsType( + options_type, allow_overwrite); } Status AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false) override { return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::AddFunctionOptionsType( - options_type, allow_overwrite); + FunctionRegistry::FunctionRegistryImpl::AddFunctionOptionsType( + options_type, allow_overwrite); } Result> GetFunction(const std::string& name) const override { @@ -238,7 +239,7 @@ class FunctionRegistry::NestedFunctionRegistryImpl int num_functions() const override { return parent_->num_functions() + - FunctionRegistry::FunctionRegistryImpl::num_functions(); + FunctionRegistry::FunctionRegistryImpl::num_functions(); } private: diff --git a/cpp/src/arrow/compute/registry_test.cc b/cpp/src/arrow/compute/registry_test.cc index 22a62ab73ad..4222bfa618f 100644 --- a/cpp/src/arrow/compute/registry_test.cc +++ b/cpp/src/arrow/compute/registry_test.cc @@ -96,15 +96,17 @@ TEST_P(TestRegistry, Basics) { } INSTANTIATE_TEST_SUITE_P( - TestRegistry, TestRegistry, - testing::Values( - std::make_tuple(static_cast( - []() { return FunctionRegistry::Make(); }), []() { return 0; }, - []() { return std::vector{}; }, "default"), - std::make_tuple(static_cast( - []() { return FunctionRegistry::Make(GetFunctionRegistry()); }), - []() { return GetFunctionRegistry()->num_functions(); }, - []() { return GetFunctionRegistry()->GetFunctionNames(); }, "nested"))); + TestRegistry, TestRegistry, + testing::Values( + std::make_tuple( + static_cast([]() { return 
FunctionRegistry::Make(); }), + []() { return 0; }, []() { return std::vector{}; }, "default"), + std::make_tuple( + static_cast([]() { + return FunctionRegistry::Make(GetFunctionRegistry()); + }), + []() { return GetFunctionRegistry()->num_functions(); }, + []() { return GetFunctionRegistry()->GetFunctionNames(); }, "nested"))); TEST(TestRegistry, RegisterTempFunctions) { auto default_registry = GetFunctionRegistry(); From 402888b7201584dbec975b76a7391f7cfc81a703 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 27 May 2022 06:30:22 -0400 Subject: [PATCH 3/8] lint --- cpp/src/arrow/compute/registry_test.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/registry_test.cc b/cpp/src/arrow/compute/registry_test.cc index 4222bfa618f..319b6be7c08 100644 --- a/cpp/src/arrow/compute/registry_test.cc +++ b/cpp/src/arrow/compute/registry_test.cc @@ -103,7 +103,7 @@ INSTANTIATE_TEST_SUITE_P( []() { return 0; }, []() { return std::vector{}; }, "default"), std::make_tuple( static_cast([]() { - return FunctionRegistry::Make(GetFunctionRegistry()); + return FunctionRegistry::Make(GetFunctionRegistry()); }), []() { return GetFunctionRegistry()->num_functions(); }, []() { return GetFunctionRegistry()->GetFunctionNames(); }, "nested"))); @@ -227,6 +227,5 @@ TEST(TestRegistry, RegisterNestedFunctions) { } } - } // namespace compute } // namespace arrow From 4a4b3a22329c3dd001a7cea1021a7a23e840e828 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 1 Jun 2022 05:28:04 -0400 Subject: [PATCH 4/8] add virtual dtor --- cpp/src/arrow/compute/registry.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 82b677ae72f..fa3fd380865 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -33,6 +33,9 @@ namespace arrow { namespace compute { class FunctionRegistry::FunctionRegistryImpl { + public: + ~FunctionRegistryImpl() {} + private: 
using FuncAdd = std::function)>; From c85421ad84f78780a12649720cb65971c0f3dc84 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 1 Jun 2022 08:06:33 -0400 Subject: [PATCH 5/8] fix virtual --- cpp/src/arrow/compute/registry.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index fa3fd380865..f214dbc6126 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -34,7 +34,7 @@ namespace compute { class FunctionRegistry::FunctionRegistryImpl { public: - ~FunctionRegistryImpl() {} + virtual ~FunctionRegistryImpl() {} private: using FuncAdd = std::function)>; From e39aa1f09836e298f3ede134a6d0a572567c1001 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 3 Jun 2022 02:59:24 -0400 Subject: [PATCH 6/8] requested fixes --- cpp/src/arrow/compute/registry.cc | 228 +++++++++---------------- cpp/src/arrow/compute/registry.h | 1 - cpp/src/arrow/compute/registry_test.cc | 23 +-- 3 files changed, 91 insertions(+), 161 deletions(-) diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index f214dbc6126..e11d0165a38 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -34,109 +34,59 @@ namespace compute { class FunctionRegistry::FunctionRegistryImpl { public: + explicit FunctionRegistryImpl(FunctionRegistryImpl* parent = NULLPTR) + : parent_(parent) {} virtual ~FunctionRegistryImpl() {} - private: - using FuncAdd = std::function)>; - - const FuncAdd kFuncAddNoOp = [](const std::string& name, - std::shared_ptr func) {}; - const FuncAdd kFuncAddDo = [this](const std::string& name, - std::shared_ptr func) { - name_to_function_[name] = func; - }; - - Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, - FuncAdd add) { -#ifndef NDEBUG - // This validates docstrings extensively, so don't waste time on it - // in release builds. 
- RETURN_NOT_OK(function->Validate()); -#endif - - std::lock_guard mutation_guard(lock_); - - const std::string& name = function->name(); - auto it = name_to_function_.find(name); - if (it != name_to_function_.end() && !allow_overwrite) { - return Status::KeyError("Already have a function registered with name: ", name); - } - add(name, std::move(function)); - return Status::OK(); - } - - public: virtual Status CanAddFunction(std::shared_ptr function, bool allow_overwrite) { - return DoAddFunction(function, allow_overwrite, kFuncAddNoOp); + return (parent_ == nullptr ? Status::OK() + : parent_->CanAddFunction(function, allow_overwrite)) & + DoAddFunction(function, allow_overwrite, /*add=*/false); } virtual Status AddFunction(std::shared_ptr function, bool allow_overwrite) { - return DoAddFunction(function, allow_overwrite, kFuncAddDo); + return (parent_ == nullptr ? Status::OK() + : parent_->CanAddFunction(function, allow_overwrite)) & + DoAddFunction(function, allow_overwrite, /*add=*/true); } - private: - Status DoAddAlias(const std::string& target_name, const std::string& source_name, - FuncAdd add) { - std::lock_guard mutation_guard(lock_); - - auto func_res = GetFunction(source_name); // must not acquire the mutex - if (!func_res.ok()) { - return Status::KeyError("No function registered with name: ", source_name); - } - add(target_name, func_res.ValueOrDie()); - return Status::OK(); - } - - public: virtual Status CanAddAlias(const std::string& target_name, const std::string& source_name) { - return DoAddAlias(target_name, source_name, kFuncAddNoOp); + Status st = DoAddAlias(target_name, source_name, /*add=*/false); + return st.ok() || parent_ == nullptr ? 
st + : parent_->CanAddAlias(target_name, source_name); } virtual Status AddAlias(const std::string& target_name, const std::string& source_name) { - return DoAddAlias(target_name, source_name, kFuncAddDo); - } - - private: - using FuncOptTypeAdd = std::function; - - const FuncOptTypeAdd kFuncOptTypeAddNoOp = [](const FunctionOptionsType* options_type) { - }; - const FuncOptTypeAdd kFuncOptTypeAddDo = - [this](const FunctionOptionsType* options_type) { - name_to_options_type_[options_type->type_name()] = options_type; - }; - - Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite, FuncOptTypeAdd add) { - std::lock_guard mutation_guard(lock_); - - const std::string name = options_type->type_name(); - auto it = name_to_options_type_.find(name); - if (it != name_to_options_type_.end() && !allow_overwrite) { - return Status::KeyError( - "Already have a function options type registered with name: ", name); - } - add(options_type); - return Status::OK(); + Status st = DoAddAlias(target_name, source_name, /*add=*/true); + return st.ok() || parent_ == nullptr ? st + : parent_->AddAlias(target_name, source_name); } - public: virtual Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false) { - return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddNoOp); + return (parent_ == nullptr + ? Status::OK() + : parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)) & + DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/false); } virtual Status AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false) { - return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddDo); + return (parent_ == nullptr + ? 
Status::OK() + : parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)) & + DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/true); } virtual Result> GetFunction(const std::string& name) const { auto it = name_to_function_.find(name); if (it == name_to_function_.end()) { + if (parent_ != nullptr) { + return parent_->GetFunction(name); + } return Status::KeyError("No function registered with name: ", name); } return it->second; @@ -144,6 +94,9 @@ class FunctionRegistry::FunctionRegistryImpl { virtual std::vector GetFunctionNames() const { std::vector results; + if (parent_ != nullptr) { + results = parent_->GetFunctionNames(); + } for (auto it : name_to_function_) { results.push_back(it.first); } @@ -155,98 +108,73 @@ class FunctionRegistry::FunctionRegistryImpl { const std::string& name) const { auto it = name_to_options_type_.find(name); if (it == name_to_options_type_.end()) { + if (parent_ != nullptr) { + return parent_->GetFunctionOptionsType(name); + } return Status::KeyError("No function options type registered with name: ", name); } return it->second; } - virtual int num_functions() const { return static_cast(name_to_function_.size()); } - - private: - std::mutex lock_; - std::unordered_map> name_to_function_; - std::unordered_map name_to_options_type_; -}; - -class FunctionRegistry::NestedFunctionRegistryImpl - : public FunctionRegistry::FunctionRegistryImpl { - public: - explicit NestedFunctionRegistryImpl(FunctionRegistry::FunctionRegistryImpl* parent) - : parent_(parent) {} - - Status CanAddFunction(std::shared_ptr function, - bool allow_overwrite) override { - return parent_->CanAddFunction(function, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::CanAddFunction(function, - allow_overwrite); + virtual int num_functions() const { + return (parent_ == nullptr ? 
0 : parent_->num_functions()) + + static_cast(name_to_function_.size()); } - Status AddFunction(std::shared_ptr function, bool allow_overwrite) override { - return parent_->CanAddFunction(function, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::AddFunction(function, allow_overwrite); - } - - Status CanAddAlias(const std::string& target_name, - const std::string& source_name) override { - Status st = - FunctionRegistry::FunctionRegistryImpl::CanAddAlias(target_name, source_name); - return st.ok() ? st : parent_->CanAddAlias(target_name, source_name); - } + private: + Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, + bool add) { +#ifndef NDEBUG + // This validates docstrings extensively, so don't waste time on it + // in release builds. + RETURN_NOT_OK(function->Validate()); +#endif - Status AddAlias(const std::string& target_name, - const std::string& source_name) override { - Status st = - FunctionRegistry::FunctionRegistryImpl::AddAlias(target_name, source_name); - return st.ok() ? 
st : parent_->AddAlias(target_name, source_name); - } + std::lock_guard mutation_guard(lock_); - Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) override { - return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::CanAddFunctionOptionsType( - options_type, allow_overwrite); + const std::string& name = function->name(); + auto it = name_to_function_.find(name); + if (it != name_to_function_.end() && !allow_overwrite) { + return Status::KeyError("Already have a function registered with name: ", name); + } + if (add) { + name_to_function_[name] = std::move(function); + } + return Status::OK(); } - Status AddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) override { - return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::AddFunctionOptionsType( - options_type, allow_overwrite); - } + Status DoAddAlias(const std::string& target_name, const std::string& source_name, + bool add) { + std::lock_guard mutation_guard(lock_); - Result> GetFunction(const std::string& name) const override { - auto func_res = FunctionRegistry::FunctionRegistryImpl::GetFunction(name); - if (func_res.ok()) { - return func_res; + // following invocation must not acquire the mutex + ARROW_ASSIGN_OR_RAISE(auto func, GetFunction(source_name)); + if (add) { + name_to_function_[target_name] = func; } - return parent_->GetFunction(name); + return Status::OK(); } - std::vector GetFunctionNames() const override { - auto names = parent_->GetFunctionNames(); - auto more_names = FunctionRegistry::FunctionRegistryImpl::GetFunctionNames(); - names.insert(names.end(), std::make_move_iterator(more_names.begin()), - std::make_move_iterator(more_names.end())); - return names; - } + Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite, bool add) { + 
std::lock_guard mutation_guard(lock_); - Result GetFunctionOptionsType( - const std::string& name) const override { - auto options_type_res = - FunctionRegistry::FunctionRegistryImpl::GetFunctionOptionsType(name); - if (options_type_res.ok()) { - return options_type_res; + const std::string name = options_type->type_name(); + auto it = name_to_options_type_.find(name); + if (it != name_to_options_type_.end() && !allow_overwrite) { + return Status::KeyError( + "Already have a function options type registered with name: ", name); } - return parent_->GetFunctionOptionsType(name); - } - - int num_functions() const override { - return parent_->num_functions() + - FunctionRegistry::FunctionRegistryImpl::num_functions(); + if (add) { + name_to_options_type_[options_type->type_name()] = options_type; + } + return Status::OK(); } - private: - FunctionRegistry::FunctionRegistryImpl* parent_; + FunctionRegistryImpl* parent_; + std::mutex lock_; + std::unordered_map> name_to_function_; + std::unordered_map name_to_options_type_; }; std::unique_ptr FunctionRegistry::Make() { @@ -254,8 +182,8 @@ std::unique_ptr FunctionRegistry::Make() { } std::unique_ptr FunctionRegistry::Make(FunctionRegistry* parent) { - return std::unique_ptr(new FunctionRegistry( - new FunctionRegistry::NestedFunctionRegistryImpl(&*parent->impl_))); + return std::unique_ptr( + new FunctionRegistry(new FunctionRegistry::FunctionRegistryImpl(&*parent->impl_))); } std::unique_ptr FunctionRegistry::Make( diff --git a/cpp/src/arrow/compute/registry.h b/cpp/src/arrow/compute/registry.h index de074e10d92..485aaf31914 100644 --- a/cpp/src/arrow/compute/registry.h +++ b/cpp/src/arrow/compute/registry.h @@ -110,7 +110,6 @@ class ARROW_EXPORT FunctionRegistry { explicit FunctionRegistry(FunctionRegistryImpl* impl); class NestedFunctionRegistryImpl; - friend class NestedFunctionRegistryImpl; }; /// \brief Return the process-global function registry diff --git a/cpp/src/arrow/compute/registry_test.cc 
b/cpp/src/arrow/compute/registry_test.cc index 319b6be7c08..937515af4ac 100644 --- a/cpp/src/arrow/compute/registry_test.cc +++ b/cpp/src/arrow/compute/registry_test.cc @@ -54,7 +54,7 @@ TEST_P(TestRegistry, Basics) { int n_funcs = get_num_funcs(); auto get_func_names = std::get<2>(GetParam()); std::vector func_names = get_func_names(); - ASSERT_EQ(n_funcs + 0, registry_->num_functions()); + ASSERT_EQ(n_funcs, registry_->num_functions()); std::shared_ptr func = std::make_shared( "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); @@ -86,6 +86,7 @@ TEST_P(TestRegistry, Basics) { for (auto name : {"f0", "f1"}) { expected_names.push_back(name); } + std::sort(expected_names.begin(), expected_names.end()); ASSERT_EQ(expected_names, registry_->GetFunctionNames()); // Aliases @@ -145,22 +146,23 @@ TEST(TestRegistry, RegisterTempAliases) { } } -template +template class ExampleOptions : public FunctionOptions { public: explicit ExampleOptions(std::shared_ptr value); std::shared_ptr value; }; -template +template class ExampleOptionsType : public FunctionOptionsType { public: static const FunctionOptionsType* GetInstance() { - static std::unique_ptr instance(new ExampleOptionsType()); + static std::unique_ptr instance( + new ExampleOptionsType()); return instance.get(); } const char* type_name() const override { - static std::string name = std::string("example") + std::to_string(N); + static std::string name = std::string("example") + std::to_string(kExampleSeqNum); return name.c_str(); } std::string Stringify(const FunctionOptions& options) const override { @@ -171,13 +173,14 @@ class ExampleOptionsType : public FunctionOptionsType { return true; } std::unique_ptr Copy(const FunctionOptions& options) const override { - const auto& opts = static_cast&>(options); - return arrow::internal::make_unique>(opts.value); + const auto& opts = static_cast&>(options); + return arrow::internal::make_unique>(opts.value); } }; -template -ExampleOptions::ExampleOptions(std::shared_ptr 
value) - : FunctionOptions(ExampleOptionsType::GetInstance()), value(std::move(value)) {} +template +ExampleOptions::ExampleOptions(std::shared_ptr value) + : FunctionOptions(ExampleOptionsType::GetInstance()), + value(std::move(value)) {} TEST(TestRegistry, RegisterTempFunctionOptionsType) { auto default_registry = GetFunctionRegistry(); From 378f55419dedc9cea442cd8a081688cd63025b32 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 9 Jun 2022 06:07:34 -0400 Subject: [PATCH 7/8] requested fixes --- cpp/src/arrow/compute/registry.cc | 143 ++++++++++++++++++------------ cpp/src/arrow/compute/registry.h | 63 +++++++------ 2 files changed, 120 insertions(+), 86 deletions(-) diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index e11d0165a38..efe42d90311 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -36,55 +36,61 @@ class FunctionRegistry::FunctionRegistryImpl { public: explicit FunctionRegistryImpl(FunctionRegistryImpl* parent = NULLPTR) : parent_(parent) {} - virtual ~FunctionRegistryImpl() {} + ~FunctionRegistryImpl() {} - virtual Status CanAddFunction(std::shared_ptr function, - bool allow_overwrite) { - return (parent_ == nullptr ? Status::OK() - : parent_->CanAddFunction(function, allow_overwrite)) & - DoAddFunction(function, allow_overwrite, /*add=*/false); + Status CanAddFunction(std::shared_ptr function, + bool allow_overwrite) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); + } + return DoAddFunction(function, allow_overwrite, /*add=*/false); } - virtual Status AddFunction(std::shared_ptr function, bool allow_overwrite) { - return (parent_ == nullptr ? 
Status::OK() - : parent_->CanAddFunction(function, allow_overwrite)) & - DoAddFunction(function, allow_overwrite, /*add=*/true); + Status AddFunction(std::shared_ptr function, bool allow_overwrite) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); + } + return DoAddFunction(function, allow_overwrite, /*add=*/true); } - virtual Status CanAddAlias(const std::string& target_name, - const std::string& source_name) { - Status st = DoAddAlias(target_name, source_name, /*add=*/false); - return st.ok() || parent_ == nullptr ? st - : parent_->CanAddAlias(target_name, source_name); + Status CanAddAlias(const std::string& target_name, + const std::string& source_name) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, + /*allow_overwrite=*/false)); + } + return DoAddAlias(target_name, source_name, /*add=*/false); } - virtual Status AddAlias(const std::string& target_name, - const std::string& source_name) { - Status st = DoAddAlias(target_name, source_name, /*add=*/true); - return st.ok() || parent_ == nullptr ? st - : parent_->AddAlias(target_name, source_name); + Status AddAlias(const std::string& target_name, + const std::string& source_name) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, + /*allow_overwrite=*/false)); + } + return DoAddAlias(target_name, source_name, /*add=*/true); } - virtual Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) { - return (parent_ == nullptr - ? 
Status::OK() - : parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)) & - DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/false); + Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); + } + return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/false); } - virtual Status AddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) { - return (parent_ == nullptr - ? Status::OK() - : parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)) & - DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/true); + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); + } + return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/true); } - virtual Result> GetFunction(const std::string& name) const { + Result> GetFunction(const std::string& name) const { auto it = name_to_function_.find(name); if (it == name_to_function_.end()) { - if (parent_ != nullptr) { + if (parent_ != NULLPTR) { return parent_->GetFunction(name); } return Status::KeyError("No function registered with name: ", name); @@ -92,9 +98,9 @@ class FunctionRegistry::FunctionRegistryImpl { return it->second; } - virtual std::vector GetFunctionNames() const { + std::vector GetFunctionNames() const { std::vector results; - if (parent_ != nullptr) { + if (parent_ != NULLPTR) { results = parent_->GetFunctionNames(); } for (auto it : name_to_function_) { @@ -104,11 +110,11 @@ class FunctionRegistry::FunctionRegistryImpl { return results; } - virtual Result GetFunctionOptionsType( + Result GetFunctionOptionsType( const std::string& name) const { auto it = 
name_to_options_type_.find(name); if (it == name_to_options_type_.end()) { - if (parent_ != nullptr) { + if (parent_ != NULLPTR) { return parent_->GetFunctionOptionsType(name); } return Status::KeyError("No function options type registered with name: ", name); @@ -116,12 +122,41 @@ class FunctionRegistry::FunctionRegistryImpl { return it->second; } - virtual int num_functions() const { - return (parent_ == nullptr ? 0 : parent_->num_functions()) + + int num_functions() const { + return (parent_ == NULLPTR ? 0 : parent_->num_functions()) + static_cast(name_to_function_.size()); } private: + // must not acquire mutex + Status CanAddFunctionName(const std::string& name, bool allow_overwrite) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionName(name, allow_overwrite)); + } + if (!allow_overwrite) { + auto it = name_to_function_.find(name); + if (it != name_to_function_.end()) { + return Status::KeyError("Already have a function registered with name: ", name); + } + } + return Status::OK(); + } + + // must not acquire mutex + Status CanAddOptionsTypeName(const std::string& name, bool allow_overwrite) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddOptionsTypeName(name, allow_overwrite)); + } + if (!allow_overwrite) { + auto it = name_to_options_type_.find(name); + if (it != name_to_options_type_.end()) { + return Status::KeyError( + "Already have a function options type registered with name: ", name); + } + } + return Status::OK(); + } + Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, bool add) { #ifndef NDEBUG @@ -133,10 +168,7 @@ class FunctionRegistry::FunctionRegistryImpl { std::lock_guard mutation_guard(lock_); const std::string& name = function->name(); - auto it = name_to_function_.find(name); - if (it != name_to_function_.end() && !allow_overwrite) { - return Status::KeyError("Already have a function registered with name: ", name); - } + RETURN_NOT_OK(CanAddFunctionName(name, allow_overwrite)); if (add) { 
name_to_function_[name] = std::move(function); } @@ -145,10 +177,14 @@ class FunctionRegistry::FunctionRegistryImpl { Status DoAddAlias(const std::string& target_name, const std::string& source_name, bool add) { + // source name must exist in this registry or the parent + // check outside mutex, in case GetFunction leads to mutex acquisition + ARROW_ASSIGN_OR_RAISE(auto func, GetFunction(source_name)); + std::lock_guard mutation_guard(lock_); - // following invocation must not acquire the mutex - ARROW_ASSIGN_OR_RAISE(auto func, GetFunction(source_name)); + // target name must be available in this registry and the parent + RETURN_NOT_OK(CanAddFunctionName(target_name, /*allow_overwrite=*/false)); if (add) { name_to_function_[target_name] = func; } @@ -160,11 +196,7 @@ class FunctionRegistry::FunctionRegistryImpl { std::lock_guard mutation_guard(lock_); const std::string name = options_type->type_name(); - auto it = name_to_options_type_.find(name); - if (it != name_to_options_type_.end() && !allow_overwrite) { - return Status::KeyError( - "Already have a function options type registered with name: ", name); - } + RETURN_NOT_OK(CanAddOptionsTypeName(name, /*allow_overwrite=*/false)); if (add) { name_to_options_type_[options_type->type_name()] = options_type; } @@ -182,13 +214,8 @@ std::unique_ptr FunctionRegistry::Make() { } std::unique_ptr FunctionRegistry::Make(FunctionRegistry* parent) { - return std::unique_ptr( - new FunctionRegistry(new FunctionRegistry::FunctionRegistryImpl(&*parent->impl_))); -} - -std::unique_ptr FunctionRegistry::Make( - std::unique_ptr parent) { - return FunctionRegistry::Make(&*parent); + return std::unique_ptr(new FunctionRegistry( + new FunctionRegistry::FunctionRegistryImpl(parent->impl_.get()))); } FunctionRegistry::FunctionRegistry() : FunctionRegistry(new FunctionRegistryImpl()) {} diff --git a/cpp/src/arrow/compute/registry.h b/cpp/src/arrow/compute/registry.h index 485aaf31914..16a76255946 100644 --- 
a/cpp/src/arrow/compute/registry.h +++ b/cpp/src/arrow/compute/registry.h @@ -45,59 +45,66 @@ class FunctionOptionsType; /// lower-level function execution. class ARROW_EXPORT FunctionRegistry { public: - virtual ~FunctionRegistry(); + ~FunctionRegistry(); - /// \brief Construct a new registry. Most users only need to use the global - /// registry + /// \brief Construct a new registry. + /// + /// Most users only need to use the global registry. static std::unique_ptr Make(); - /// \brief Construct a new nested registry with the given parent. Most users only need - /// to use the global registry + /// \brief Construct a new nested registry with the given parent. + /// + /// Most users only need to use the global registry. The returned registry never changes + /// its parent, even when an operation allows overwriting. static std::unique_ptr Make(FunctionRegistry* parent); - /// \brief Construct a new nested registry with the given parent. Most users only need - /// to use the global registry - static std::unique_ptr Make(std::unique_ptr parent); - - /// \brief Checks whether a new function can be added to the registry. Returns - /// Status::KeyError if a function with the same name is already registered + /// \brief Check whether a new function can be added to the registry. + /// + /// \returns Status::KeyError if a function with the same name is already registered. Status CanAddFunction(std::shared_ptr function, bool allow_overwrite = false); - /// \brief Add a new function to the registry. Returns Status::KeyError if a - /// function with the same name is already registered + /// \brief Add a new function to the registry. + /// + /// \returns Status::KeyError if a function with the same name is already registered. Status AddFunction(std::shared_ptr function, bool allow_overwrite = false); - /// \brief Checks whether an alias can be added for the given function name.
Returns - /// Status::KeyError if the function with the given name is not registered + /// \brief Check whether an alias can be added for the given function name. + /// + /// \returns Status::KeyError if the function with the given name is not registered. Status CanAddAlias(const std::string& target_name, const std::string& source_name); - /// \brief Add alias for the given function name. Returns Status::KeyError if the - /// function with the given name is not registered + /// \brief Add alias for the given function name. + /// + /// \returns Status::KeyError if the function with the given name is not registered. Status AddAlias(const std::string& target_name, const std::string& source_name); - /// \brief Checks whether a new function options type can be added to the registry. - /// Returns Status::KeyError if a function options type with the same name is already - /// registered + /// \brief Check whether a new function options type can be added to the registry. + /// + /// \returns Status::KeyError if a function options type with the same name is already + /// registered. Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false); - /// \brief Add a new function options type to the registry. Returns Status::KeyError if - /// a function options type with the same name is already registered + /// \brief Add a new function options type to the registry. + /// + /// \returns Status::KeyError if a function options type with the same name is already + /// registered. Status AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false); - /// \brief Retrieve a function by name from the registry + /// \brief Retrieve a function by name from the registry. Result> GetFunction(const std::string& name) const; - /// \brief Return vector of all entry names in the registry. Helpful for - /// displaying a manifest of available functions + /// \brief Return vector of all entry names in the registry. 
+ /// + /// Helpful for displaying a manifest of available functions. std::vector GetFunctionNames() const; - /// \brief Retrieve a function options type by name from the registry + /// \brief Retrieve a function options type by name from the registry. Result GetFunctionOptionsType( const std::string& name) const; - /// \brief The number of currently registered functions + /// \brief The number of currently registered functions. int num_functions() const; private: @@ -112,7 +119,7 @@ class ARROW_EXPORT FunctionRegistry { class NestedFunctionRegistryImpl; }; -/// \brief Return the process-global function registry +/// \brief Return the process-global function registry. ARROW_EXPORT FunctionRegistry* GetFunctionRegistry(); } // namespace compute From d7a5f9d1562e9423eb59a3bcd377dad87db70970 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 9 Jun 2022 06:53:51 -0400 Subject: [PATCH 8/8] lint --- cpp/src/arrow/compute/registry.cc | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index efe42d90311..fe7c6fa8ad1 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -38,35 +38,32 @@ class FunctionRegistry::FunctionRegistryImpl { : parent_(parent) {} ~FunctionRegistryImpl() {} - Status CanAddFunction(std::shared_ptr function, - bool allow_overwrite) { + Status CanAddFunction(std::shared_ptr function, bool allow_overwrite) { if (parent_ != NULLPTR) { - RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); + RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); } return DoAddFunction(function, allow_overwrite, /*add=*/false); } Status AddFunction(std::shared_ptr function, bool allow_overwrite) { if (parent_ != NULLPTR) { - RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); + RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); } return DoAddFunction(function, 
allow_overwrite, /*add=*/true); } - Status CanAddAlias(const std::string& target_name, - const std::string& source_name) { + Status CanAddAlias(const std::string& target_name, const std::string& source_name) { if (parent_ != NULLPTR) { - RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, - /*allow_overwrite=*/false)); + RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, + /*allow_overwrite=*/false)); } return DoAddAlias(target_name, source_name, /*add=*/false); } - Status AddAlias(const std::string& target_name, - const std::string& source_name) { + Status AddAlias(const std::string& target_name, const std::string& source_name) { if (parent_ != NULLPTR) { - RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, - /*allow_overwrite=*/false)); + RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, + /*allow_overwrite=*/false)); } return DoAddAlias(target_name, source_name, /*add=*/true); } @@ -74,7 +71,7 @@ class FunctionRegistry::FunctionRegistryImpl { Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false) { if (parent_ != NULLPTR) { - RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); + RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); } return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/false); } @@ -82,7 +79,7 @@ class FunctionRegistry::FunctionRegistryImpl { Status AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false) { if (parent_ != NULLPTR) { - RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); + RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); } return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/true); }