Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 216 additions & 121 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ struct ExtensionSet::Impl {
std::unordered_map<Id, uint32_t, IdHashEq, IdHashEq> types_, functions_;
};

ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry)
// Builds an extension set bound to `registry`. The set only stores the
// pointer; it never modifies the registry, hence the pointer-to-const.
// impl_ is given an explicit deleter lambda that deletes the Impl instance.
ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry)
    : registry_(registry),
      impl_(new Impl(), [](Impl* owned_impl) { delete owned_impl; }) {}

Result<ExtensionSet> ExtensionSet::Make(std::vector<util::string_view> uris,
std::vector<Id> type_ids,
std::vector<bool> type_is_variation,
std::vector<Id> function_ids,
ExtensionIdRegistry* registry) {
const ExtensionIdRegistry* registry) {
ExtensionSet set;
set.registry_ = registry;

Expand Down Expand Up @@ -210,158 +210,253 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
return &it->second;
}

ExtensionIdRegistry* default_extension_id_registry() {
static struct Impl : ExtensionIdRegistry {
Impl() {
struct TypeName {
std::shared_ptr<DataType> type;
util::string_view name;
};

// The type (variation) mappings listed below need to be kept in sync
// with the YAML at substrait/format/extension_types.yaml manually;
// see ARROW-15535.
for (TypeName e : {
TypeName{uint8(), "u8"},
TypeName{uint16(), "u16"},
TypeName{uint32(), "u32"},
TypeName{uint64(), "u64"},
TypeName{float16(), "fp16"},
}) {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
/*is_variation=*/true));
}

for (TypeName e : {
TypeName{null(), "null"},
TypeName{month_interval(), "interval_month"},
TypeName{day_time_interval(), "interval_day_milli"},
TypeName{month_day_nano_interval(), "interval_month_day_nano"},
}) {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
/*is_variation=*/false));
}

// TODO: this is just a placeholder right now. We'll need a YAML file for
// all functions (and prototypes) that Arrow provides that are relevant
// for Substrait, and include mappings for all of them here. See
// ARROW-15535.
for (util::string_view name : {
"add",
}) {
DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
}
}
namespace {

struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
virtual ~ExtensionIdRegistryImpl() {}

std::vector<util::string_view> Uris() const override {
return {uris_.begin(), uris_.end()};
std::vector<util::string_view> Uris() const override {
return {uris_.begin(), uris_.end()};
}

// Looks up the record registered for an Arrow data type.
// Returns an empty optional when GetIndex yields a null pointer for the
// type in type_to_index_.
util::optional<TypeRecord> GetType(const DataType& type) const override {
  const auto slot = GetIndex(type_to_index_, &type);
  if (!slot) {
    return {};
  }
  const auto position = *slot;
  return TypeRecord{type_ids_[position], types_[position], type_is_variation_[position]};
}

util::optional<TypeRecord> GetType(const DataType& type) const override {
if (auto index = GetIndex(type_to_index_, &type)) {
return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]};
}
return {};
// Looks up a type record by id. `is_variation` selects the table that is
// searched: variation_id_to_index_ when true, id_to_index_ otherwise.
// Returns an empty optional when the id is not present in that table.
util::optional<TypeRecord> GetType(Id id, bool is_variation) const override {
  const auto& table = is_variation ? variation_id_to_index_ : id_to_index_;
  const auto slot = GetIndex(table, id);
  if (!slot) {
    return {};
  }
  const auto position = *slot;
  return TypeRecord{type_ids_[position], types_[position], type_is_variation_[position]};
}

util::optional<TypeRecord> GetType(Id id, bool is_variation) const override {
if (auto index =
GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) {
return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]};
}
return {};
// Read-only check of whether RegisterType would accept this id/type pair.
// Fails with Status::Invalid when the id is already present in the table
// selected by `is_variation`, or when the type is already registered.
// Nothing is modified.
Status CanRegisterType(Id id, std::shared_ptr<DataType> type,
                       bool is_variation) const override {
  const auto& ids_in_use = is_variation ? variation_id_to_index_ : id_to_index_;
  if (ids_in_use.count(id) != 0) {
    return Status::Invalid("Type id was already registered");
  }
  if (type_to_index_.count(type.get()) != 0) {
    return Status::Invalid("Type was already registered");
  }
  return Status::OK();
}

Status RegisterType(Id id, std::shared_ptr<DataType> type,
bool is_variation) override {
DCHECK_EQ(type_ids_.size(), types_.size());
DCHECK_EQ(type_ids_.size(), type_is_variation_.size());
Status RegisterType(Id id, std::shared_ptr<DataType> type, bool is_variation) override {
DCHECK_EQ(type_ids_.size(), types_.size());
DCHECK_EQ(type_ids_.size(), type_is_variation_.size());

Id copied_id{*uris_.emplace(id.uri.to_string()).first,
*names_.emplace(id.name.to_string()).first};
Id copied_id{*uris_.emplace(id.uri.to_string()).first,
*names_.emplace(id.name.to_string()).first};

auto index = static_cast<int>(type_ids_.size());
auto index = static_cast<int>(type_ids_.size());

auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_;
auto it_success = id_to_index->emplace(copied_id, index);
auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_;
auto it_success = id_to_index->emplace(copied_id, index);

if (!it_success.second) {
return Status::Invalid("Type id was already registered");
}
if (!it_success.second) {
return Status::Invalid("Type id was already registered");
}

if (!type_to_index_.emplace(type.get(), index).second) {
id_to_index->erase(it_success.first);
return Status::Invalid("Type was already registered");
}
if (!type_to_index_.emplace(type.get(), index).second) {
id_to_index->erase(it_success.first);
return Status::Invalid("Type was already registered");
}

type_ids_.push_back(copied_id);
types_.push_back(std::move(type));
type_is_variation_.push_back(is_variation);
return Status::OK();
}

type_ids_.push_back(copied_id);
types_.push_back(std::move(type));
type_is_variation_.push_back(is_variation);
return Status::OK();
// Looks up a function record by its Arrow function name.
// Returns an empty optional when the name is not in function_name_to_index_.
util::optional<FunctionRecord> GetFunction(
    util::string_view arrow_function_name) const override {
  const auto slot = GetIndex(function_name_to_index_, arrow_function_name);
  if (!slot) {
    return {};
  }
  const auto position = *slot;
  return FunctionRecord{function_ids_[position], *function_name_ptrs_[position]};
}

// Looks up a function record by its id.
// Returns an empty optional when the id is not in function_id_to_index_.
util::optional<FunctionRecord> GetFunction(Id id) const override {
  const auto slot = GetIndex(function_id_to_index_, id);
  if (!slot) {
    return {};
  }
  const auto position = *slot;
  return FunctionRecord{function_ids_[position], *function_name_ptrs_[position]};
}

// Read-only check of whether RegisterFunction would accept this id/name
// pair. Fails with Status::Invalid when either the id or the Arrow function
// name is already registered. Nothing is modified.
Status CanRegisterFunction(Id id, std::string arrow_function_name) const override {
  const bool id_taken = function_id_to_index_.count(id) != 0;
  if (id_taken) {
    return Status::Invalid("Function id was already registered");
  }
  const bool name_taken = function_name_to_index_.count(arrow_function_name) != 0;
  if (name_taken) {
    return Status::Invalid("Function name was already registered");
  }
  return Status::OK();
}

Status RegisterFunction(Id id, std::string arrow_function_name) override {
DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());

Id copied_id{*uris_.emplace(id.uri.to_string()).first,
*names_.emplace(id.name.to_string()).first};

const std::string& copied_function_name{
*function_names_.emplace(std::move(arrow_function_name)).first};

auto index = static_cast<int>(function_ids_.size());

auto it_success = function_id_to_index_.emplace(copied_id, index);

util::optional<FunctionRecord> GetFunction(
util::string_view arrow_function_name) const override {
if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
}
return {};
if (!it_success.second) {
return Status::Invalid("Function id was already registered");
}

util::optional<FunctionRecord> GetFunction(Id id) const override {
if (auto index = GetIndex(function_id_to_index_, id)) {
return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
}
return {};
if (!function_name_to_index_.emplace(copied_function_name, index).second) {
function_id_to_index_.erase(it_success.first);
return Status::Invalid("Function name was already registered");
}

Status RegisterFunction(Id id, std::string arrow_function_name) override {
DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
function_name_ptrs_.push_back(&copied_function_name);
function_ids_.push_back(copied_id);
return Status::OK();
}

Id copied_id{*uris_.emplace(id.uri.to_string()).first,
*names_.emplace(id.name.to_string()).first};
// owning storage of uris, names, (arrow::)function_names, types
// note that storing strings like this is safe since references into an
// unordered_set are not invalidated on insertion
std::unordered_set<std::string> uris_, names_, function_names_;
DataTypeVector types_;
std::vector<bool> type_is_variation_;

// non-owning lookup helpers
std::vector<Id> type_ids_, function_ids_;
std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_, variation_id_to_index_;
std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;

std::vector<const std::string*> function_name_ptrs_;
std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
function_name_to_index_;
};

const std::string& copied_function_name{
*function_names_.emplace(std::move(arrow_function_name)).first};
struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
: parent_(parent) {}

auto index = static_cast<int>(function_ids_.size());
virtual ~NestedExtensionIdRegistryImpl() {}

auto it_success = function_id_to_index_.emplace(copied_id, index);
std::vector<util::string_view> Uris() const override {
std::vector<util::string_view> uris = parent_->Uris();
std::unordered_set<util::string_view> uri_set;
uri_set.insert(uris.begin(), uris.end());
uri_set.insert(uris_.begin(), uris_.end());
return std::vector<util::string_view>(uris);
}

if (!it_success.second) {
return Status::Invalid("Function id was already registered");
}
// Resolves a type record for `type`, preferring a local registration and
// falling back to the parent registry when this registry has none.
util::optional<TypeRecord> GetType(const DataType& type) const override {
  if (auto local_record = ExtensionIdRegistryImpl::GetType(type)) {
    return local_record;
  }
  return parent_->GetType(type);
}

// Resolves a type record by id, preferring a local registration and falling
// back to the parent registry when this registry has none.
util::optional<TypeRecord> GetType(Id id, bool is_variation) const override {
  if (auto local_record = ExtensionIdRegistryImpl::GetType(id, is_variation)) {
    return local_record;
  }
  return parent_->GetType(id, is_variation);
}

if (!function_name_to_index_.emplace(copied_function_name, index).second) {
function_id_to_index_.erase(it_success.first);
return Status::Invalid("Function name was already registered");
}
// Registers a type in this nested registry, provided the parent registry
// reports (via CanRegisterType) that the id and type do not clash with its
// own registrations.
//
// Returns the parent's error status without touching this registry when the
// parent check fails; otherwise returns the result of the local
// registration.
//
// The checks are sequenced explicitly: combining the two statuses with `&`
// evaluates both operands, so the local registration would be performed
// even when the parent had rejected it.
Status RegisterType(Id id, std::shared_ptr<DataType> type, bool is_variation) override {
  Status parent_check = parent_->CanRegisterType(id, type, is_variation);
  if (!parent_check.ok()) {
    return parent_check;
  }
  return ExtensionIdRegistryImpl::RegisterType(id, std::move(type), is_variation);
}

function_name_ptrs_.push_back(&copied_function_name);
function_ids_.push_back(copied_id);
return Status::OK();
// Resolves a function record by Arrow function name, preferring a local
// registration and falling back to the parent registry otherwise.
util::optional<FunctionRecord> GetFunction(
    util::string_view arrow_function_name) const override {
  if (auto local_record = ExtensionIdRegistryImpl::GetFunction(arrow_function_name)) {
    return local_record;
  }
  return parent_->GetFunction(arrow_function_name);
}

// owning storage of uris, names, (arrow::)function_names, types
// note that storing strings like this is safe since references into an
// unordered_set are not invalidated on insertion
std::unordered_set<std::string> uris_, names_, function_names_;
DataTypeVector types_;
std::vector<bool> type_is_variation_;
// Resolves a function record by id, preferring a local registration and
// falling back to the parent registry otherwise.
util::optional<FunctionRecord> GetFunction(Id id) const override {
  if (auto local_record = ExtensionIdRegistryImpl::GetFunction(id)) {
    return local_record;
  }
  return parent_->GetFunction(id);
}

// non-owning lookup helpers
std::vector<Id> type_ids_, function_ids_;
std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_, variation_id_to_index_;
std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
// Registers a function in this nested registry, provided the parent
// registry reports (via CanRegisterFunction) that the id and Arrow function
// name do not clash with its own registrations.
//
// Returns the parent's error status without touching this registry when the
// parent check fails; otherwise returns the result of the local
// registration.
//
// The checks are sequenced explicitly: combining the two statuses with `&`
// evaluates both operands, so the local registration would be performed
// even when the parent had rejected it.
Status RegisterFunction(Id id, std::string arrow_function_name) override {
  Status parent_check = parent_->CanRegisterFunction(id, arrow_function_name);
  if (!parent_check.ok()) {
    return parent_check;
  }
  return ExtensionIdRegistryImpl::RegisterFunction(id, std::move(arrow_function_name));
}

std::vector<const std::string*> function_name_ptrs_;
std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
function_name_to_index_;
} impl_;
const ExtensionIdRegistry* parent_;
};

struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
DefaultExtensionIdRegistry() {
struct TypeName {
std::shared_ptr<DataType> type;
util::string_view name;
};

// The type (variation) mappings listed below need to be kept in sync
// with the YAML at substrait/format/extension_types.yaml manually;
// see ARROW-15535.
for (TypeName e : {
TypeName{uint8(), "u8"},
TypeName{uint16(), "u16"},
TypeName{uint32(), "u32"},
TypeName{uint64(), "u64"},
TypeName{float16(), "fp16"},
}) {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
/*is_variation=*/true));
}

for (TypeName e : {
TypeName{null(), "null"},
TypeName{month_interval(), "interval_month"},
TypeName{day_time_interval(), "interval_day_milli"},
TypeName{month_day_nano_interval(), "interval_month_day_nano"},
}) {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
/*is_variation=*/false));
}

// TODO: this is just a placeholder right now. We'll need a YAML file for
// all functions (and prototypes) that Arrow provides that are relevant
// for Substrait, and include mappings for all of them here. See
// ARROW-15535.
for (util::string_view name : {
"add",
}) {
DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
}
}
};

} // namespace

// Returns a pointer to the shared default registry. The registry is a
// function-local static, so it is constructed on the first call and every
// call returns the same instance.
ExtensionIdRegistry* default_extension_id_registry() {
  static DefaultExtensionIdRegistry default_registry;
  return &default_registry;
}

// Creates a new, initially empty registry layered on top of `parent`.
// The nested registry stores the parent pointer without taking ownership.
// NOTE(review): the parent presumably has to outlive the returned registry;
// confirm against callers.
std::shared_ptr<ExtensionIdRegistry> nested_extension_id_registry(
    const ExtensionIdRegistry* parent) {
  std::shared_ptr<ExtensionIdRegistry> nested =
      std::make_shared<NestedExtensionIdRegistryImpl>(parent);
  return nested;
}

} // namespace engine
} // namespace arrow
Loading