Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 147 additions & 30 deletions cpp/src/arrow/compute/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,59 +34,72 @@ namespace compute {

class FunctionRegistry::FunctionRegistryImpl {
public:
Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
#ifndef NDEBUG
// This validates docstrings extensively, so don't waste time on it
// in release builds.
RETURN_NOT_OK(function->Validate());
#endif
explicit FunctionRegistryImpl(FunctionRegistryImpl* parent = NULLPTR)
: parent_(parent) {}
~FunctionRegistryImpl() {}

std::lock_guard<std::mutex> mutation_guard(lock_);
/// Dry-run counterpart of AddFunction: report whether `function` could be
/// registered, without modifying this registry.
///
/// When a parent exists it is asked first, and any error it reports is
/// returned unchanged. Otherwise the local check decides.
Status CanAddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
  if (parent_ == NULLPTR) {
    return DoAddFunction(function, allow_overwrite, /*add=*/false);
  }
  RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite));
  return DoAddFunction(function, allow_overwrite, /*add=*/false);
}

const std::string& name = function->name();
auto it = name_to_function_.find(name);
if (it != name_to_function_.end() && !allow_overwrite) {
return Status::KeyError("Already have a function registered with name: ", name);
/// Register `function` in this registry.
///
/// The parent (if any) is only asked whether the addition would be allowed;
/// it is never modified. The actual validation, locking and insertion are
/// done by DoAddFunction.
///
/// \param function the function to register
/// \param allow_overwrite passed through to the parent check and to
///        DoAddFunction
/// \return the first non-OK status from the parent check or from
///         DoAddFunction, Status::OK() otherwise
Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
  if (parent_ != NULLPTR) {
    RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite));
  }
  // No direct write to name_to_function_ here: DoAddFunction performs the
  // insertion under lock_.
  return DoAddFunction(std::move(function), allow_overwrite, /*add=*/true);
}

/// Dry-run counterpart of AddAlias: report whether `target_name` could be
/// registered as an alias of `source_name`, without modifying this registry.
///
/// The target name is first checked against the parent chain (never allowing
/// an overwrite); the remaining checks are done by DoAddAlias.
Status CanAddAlias(const std::string& target_name, const std::string& source_name) {
  if (parent_ == NULLPTR) {
    return DoAddAlias(target_name, source_name, /*add=*/false);
  }
  RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, /*allow_overwrite=*/false));
  return DoAddAlias(target_name, source_name, /*add=*/false);
}

/// Register `target_name` as an alias of the function named `source_name`.
///
/// \param target_name the new name; must not be taken in the parent chain or
///        in this registry
/// \param source_name the name of an existing function, resolved by
///        DoAddAlias through GetFunction
/// \return the first non-OK status from the parent check or from DoAddAlias,
///         Status::OK() otherwise
Status AddAlias(const std::string& target_name, const std::string& source_name) {
  // lock_ is deliberately not taken here: DoAddAlias acquires it itself, and
  // locking a std::mutex twice from the same thread is not allowed.
  if (parent_ != NULLPTR) {
    RETURN_NOT_OK(parent_->CanAddFunctionName(target_name,
                                              /*allow_overwrite=*/false));
  }
  return DoAddAlias(target_name, source_name, /*add=*/true);
}

auto it = name_to_function_.find(source_name);
if (it == name_to_function_.end()) {
return Status::KeyError("No function registered with name: ", source_name);
/// Dry-run counterpart of AddFunctionOptionsType: report whether
/// `options_type` could be registered, without modifying this registry.
///
/// \param options_type the options type to check
/// \param allow_overwrite passed through to the parent check and to
///        DoAddFunctionOptionsType
/// \return the first non-OK status from the parent check or from
///         DoAddFunctionOptionsType, Status::OK() otherwise
Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type,
                                 bool allow_overwrite = false) {
  if (parent_ != NULLPTR) {
    RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite));
  }
  // Check only: add=false keeps DoAddFunctionOptionsType from inserting.
  return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/false);
}

/// Register `options_type` in this registry.
///
/// The parent (if any) is only asked whether the addition would be allowed;
/// it is never modified. Locking, the local name check and the insertion are
/// done by DoAddFunctionOptionsType.
///
/// \param options_type the options type to register
/// \param allow_overwrite passed through to the parent check and to
///        DoAddFunctionOptionsType
/// \return the first non-OK status from the parent check or from
///         DoAddFunctionOptionsType, Status::OK() otherwise
Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
                              bool allow_overwrite = false) {
  // lock_ is not taken here: DoAddFunctionOptionsType acquires it itself.
  if (parent_ != NULLPTR) {
    RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite));
  }
  return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/true);
}

/// Look up a function by name.
///
/// This registry's own entries take precedence. If the name is not found
/// here, the lookup is forwarded to the parent when there is one; otherwise
/// a KeyError is returned.
Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const {
  const auto found = name_to_function_.find(name);
  if (found != name_to_function_.end()) {
    return found->second;
  }
  if (parent_ == NULLPTR) {
    return Status::KeyError("No function registered with name: ", name);
  }
  return parent_->GetFunction(name);
}

std::vector<std::string> GetFunctionNames() const {
std::vector<std::string> results;
if (parent_ != NULLPTR) {
results = parent_->GetFunctionNames();
}
for (auto it : name_to_function_) {
results.push_back(it.first);
}
Expand All @@ -98,14 +111,96 @@ class FunctionRegistry::FunctionRegistryImpl {
const std::string& name) const {
auto it = name_to_options_type_.find(name);
if (it == name_to_options_type_.end()) {
if (parent_ != NULLPTR) {
return parent_->GetFunctionOptionsType(name);
}
return Status::KeyError("No function options type registered with name: ", name);
}
return it->second;
}

/// Number of registered function names, including those of the parent chain.
///
/// The result is the sum of the per-registry map sizes, so a name present
/// both here and in an ancestor is counted once per registry.
int num_functions() const {
  const int inherited = (parent_ == NULLPTR) ? 0 : parent_->num_functions();
  return inherited + static_cast<int>(name_to_function_.size());
}

private:
// Check that `name` is free in the parent chain and in this registry.
//
// must not acquire mutex: callers such as DoAddFunction and DoAddAlias
// already hold lock_ when they call this.
Status CanAddFunctionName(const std::string& name, bool allow_overwrite) {
  if (parent_ != NULLPTR) {
    RETURN_NOT_OK(parent_->CanAddFunctionName(name, allow_overwrite));
  }
  // The local map is only consulted when overwriting is disallowed.
  const bool name_taken =
      !allow_overwrite && name_to_function_.find(name) != name_to_function_.end();
  if (name_taken) {
    return Status::KeyError("Already have a function registered with name: ", name);
  }
  return Status::OK();
}

// Check that the options type name `name` is free in the parent chain and in
// this registry.
//
// must not acquire mutex: DoAddFunctionOptionsType already holds lock_ when
// it calls this.
Status CanAddOptionsTypeName(const std::string& name, bool allow_overwrite) {
  if (parent_ != NULLPTR) {
    RETURN_NOT_OK(parent_->CanAddOptionsTypeName(name, allow_overwrite));
  }
  // The local map is only consulted when overwriting is disallowed.
  const bool name_taken = !allow_overwrite &&
                          name_to_options_type_.find(name) != name_to_options_type_.end();
  if (name_taken) {
    return Status::KeyError(
        "Already have a function options type registered with name: ", name);
  }
  return Status::OK();
}

// Shared implementation of CanAddFunction (add=false) and AddFunction
// (add=true): validate in debug builds, take lock_, check that the name is
// available, and insert only when `add` is set.
Status DoAddFunction(std::shared_ptr<Function> function, bool allow_overwrite,
                     bool add) {
#ifndef NDEBUG
  // Validation goes through the docstrings extensively, so it is skipped in
  // release builds to save time.
  RETURN_NOT_OK(function->Validate());
#endif

  std::lock_guard<std::mutex> mutation_guard(lock_);

  const std::string& function_name = function->name();
  RETURN_NOT_OK(CanAddFunctionName(function_name, allow_overwrite));
  if (!add) {
    // Check-only mode: leave the map untouched.
    return Status::OK();
  }
  name_to_function_[function_name] = std::move(function);
  return Status::OK();
}

// Shared implementation of CanAddAlias (add=false) and AddAlias (add=true).
//
// \param target_name the alias to create; must not already be registered in
//        this registry or the parent chain
// \param source_name the name of an existing function, looked up through
//        GetFunction (this registry first, then the parent)
// \param add when false, only the checks are run and nothing is inserted
// \return the error from GetFunction if the source is unknown, KeyError if
//         the target name is already registered, Status::OK() otherwise
Status DoAddAlias(const std::string& target_name, const std::string& source_name,
                  bool add) {
  // source name must exist in this registry or the parent
  // check outside mutex, in case GetFunction leads to mutex acquisition
  ARROW_ASSIGN_OR_RAISE(auto func, GetFunction(source_name));

  std::lock_guard<std::mutex> mutation_guard(lock_);

  // target name must be available in this registry and the parent
  RETURN_NOT_OK(CanAddFunctionName(target_name, /*allow_overwrite=*/false));
  if (add) {
    // Review question raised at this point: should we check whether a
    // function with target_name is already registered? The CanAddFunctionName
    // call above is that check; aliases never overwrite an existing name.
    name_to_function_[target_name] = std::move(func);
  }
  return Status::OK();
}

// Shared implementation of CanAddFunctionOptionsType (add=false) and
// AddFunctionOptionsType (add=true).
//
// \param options_type the options type; its type_name() is the registry key
// \param allow_overwrite when true, an existing entry with the same name is
//        not an error and is replaced on insertion
// \param add when false, only the name check is run and nothing is inserted
// \return KeyError if the name is taken and overwriting is not allowed,
//         Status::OK() otherwise
Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type,
                                bool allow_overwrite, bool add) {
  std::lock_guard<std::mutex> mutation_guard(lock_);

  const std::string name = options_type->type_name();
  // Honour the caller's allow_overwrite instead of always rejecting
  // duplicates.
  RETURN_NOT_OK(CanAddOptionsTypeName(name, allow_overwrite));
  if (add) {
    name_to_options_type_[name] = options_type;
  }
  return Status::OK();
}

FunctionRegistryImpl* parent_;
std::mutex lock_;
std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_;
std::unordered_map<std::string, const FunctionOptionsType*> name_to_options_type_;
Expand All @@ -115,20 +210,42 @@ std::unique_ptr<FunctionRegistry> FunctionRegistry::Make() {
return std::unique_ptr<FunctionRegistry>(new FunctionRegistry());
}

FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); }
// Build a registry nested under `parent`: the new implementation keeps a
// pointer to the parent's implementation object. `parent` is dereferenced
// unconditionally, so it must not be null.
std::unique_ptr<FunctionRegistry> FunctionRegistry::Make(FunctionRegistry* parent) {
  FunctionRegistryImpl* const parent_impl = parent->impl_.get();
  std::unique_ptr<FunctionRegistry> nested(
      new FunctionRegistry(new FunctionRegistry::FunctionRegistryImpl(parent_impl)));
  return nested;
}

// Default construction: delegate to the impl-taking constructor with a
// freshly created implementation that has no parent.
FunctionRegistry::FunctionRegistry()
    : FunctionRegistry(new FunctionRegistryImpl()) {}

// Take ownership of `impl`; it is held by the impl_ unique_ptr from here on.
FunctionRegistry::FunctionRegistry(FunctionRegistryImpl* impl) : impl_(impl) {}

// Defined out of line so that FunctionRegistryImpl is a complete type where
// impl_ is destroyed.
FunctionRegistry::~FunctionRegistry() = default;

// Public entry point: forwards to the implementation object, handing over
// the caller's shared_ptr with std::move.
Status FunctionRegistry::CanAddFunction(std::shared_ptr<Function> function,
bool allow_overwrite) {
return impl_->CanAddFunction(std::move(function), allow_overwrite);
}

// Public entry point: forwards to the implementation object, handing over
// the caller's shared_ptr with std::move.
Status FunctionRegistry::AddFunction(std::shared_ptr<Function> function,
bool allow_overwrite) {
return impl_->AddFunction(std::move(function), allow_overwrite);
}

// Public entry point: forwards both names unchanged to the implementation
// object.
Status FunctionRegistry::CanAddAlias(const std::string& target_name,
const std::string& source_name) {
return impl_->CanAddAlias(target_name, source_name);
}

// Public entry point: forwards both names unchanged to the implementation
// object.
Status FunctionRegistry::AddAlias(const std::string& target_name,
const std::string& source_name) {
return impl_->AddAlias(target_name, source_name);
}

// Public entry point: forwards the options type pointer and the overwrite
// flag unchanged to the implementation object.
Status FunctionRegistry::CanAddFunctionOptionsType(
const FunctionOptionsType* options_type, bool allow_overwrite) {
return impl_->CanAddFunctionOptionsType(options_type, allow_overwrite);
}

Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type,
bool allow_overwrite) {
return impl_->AddFunctionOptionsType(options_type, allow_overwrite);
Expand Down
61 changes: 47 additions & 14 deletions cpp/src/arrow/compute/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,35 +47,64 @@ class ARROW_EXPORT FunctionRegistry {
public:
~FunctionRegistry();

/// \brief Construct a new registry. Most users only need to use the global
/// registry
/// \brief Construct a new registry.
///
/// Most users only need to use the global registry.
static std::unique_ptr<FunctionRegistry> Make();

/// \brief Add a new function to the registry. Returns Status::KeyError if a
/// function with the same name is already registered
/// \brief Construct a new nested registry with the given parent.
///
/// Most users only need to use the global registry. The returned registry never changes
/// its parent, even when an operation allows overwriting.
static std::unique_ptr<FunctionRegistry> Make(FunctionRegistry* parent);

/// \brief Check whether a new function can be added to the registry.
///
/// \returns Status::KeyError if a function with the same name is already registered.
Status CanAddFunction(std::shared_ptr<Function> function, bool allow_overwrite = false);

/// \brief Add a new function to the registry.
///
/// \returns Status::KeyError if a function with the same name is already registered.
Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite = false);

/// \brief Add aliases for the given function name. Returns Status::KeyError if the
/// function with the given name is not registered
/// \brief Check whether an alias can be added for the given function name.
///
/// \returns Status::KeyError if no function named source_name is registered, or if
/// target_name is already registered.
Status CanAddAlias(const std::string& target_name, const std::string& source_name);

/// \brief Add alias for the given function name.
///
/// \returns Status::KeyError if no function named source_name is registered, or if
/// target_name is already registered.
Status AddAlias(const std::string& target_name, const std::string& source_name);

/// \brief Add a new function options type to the registry. Returns Status::KeyError if
/// a function options type with the same name is already registered
/// \brief Check whether a new function options type can be added to the registry.
///
/// \returns Status::KeyError if a function options type with the same name is already
/// registered.
Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type,
bool allow_overwrite = false);

/// \brief Add a new function options type to the registry.
///
/// \returns Status::KeyError if a function options type with the same name is already
/// registered.
Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
bool allow_overwrite = false);

/// \brief Retrieve a function by name from the registry
/// \brief Retrieve a function by name from the registry.
Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const;

/// \brief Return vector of all entry names in the registry. Helpful for
/// displaying a manifest of available functions
/// \brief Return vector of all entry names in the registry.
///
/// Helpful for displaying a manifest of available functions.
std::vector<std::string> GetFunctionNames() const;

/// \brief Retrieve a function options type by name from the registry
/// \brief Retrieve a function options type by name from the registry.
Result<const FunctionOptionsType*> GetFunctionOptionsType(
const std::string& name) const;

/// \brief The number of currently registered functions
/// \brief The number of currently registered functions.
int num_functions() const;

private:
Expand All @@ -84,9 +113,13 @@ class ARROW_EXPORT FunctionRegistry {
// Use PIMPL pattern to not have std::unordered_map here
class FunctionRegistryImpl;
std::unique_ptr<FunctionRegistryImpl> impl_;

explicit FunctionRegistry(FunctionRegistryImpl* impl);

class NestedFunctionRegistryImpl;
};

/// \brief Return the process-global function registry
/// \brief Return the process-global function registry.
ARROW_EXPORT FunctionRegistry* GetFunctionRegistry();

} // namespace compute
Expand Down
Loading