Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 128 additions & 18 deletions cpp/src/arrow/dataset/partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
#include "arrow/util/logging.h"
#include "arrow/util/make_unique.h"
#include "arrow/util/string_view.h"
#include "arrow/util/uri.h"
#include "arrow/util/utf8.h"

namespace arrow {

Expand All @@ -46,6 +48,18 @@ using util::string_view;

namespace dataset {

namespace {
/// Apply UriUnescape, then ensure the results are valid UTF-8.
Result<std::string> SafeUriUnescape(util::string_view encoded) {
auto decoded = internal::UriUnescape(encoded);
if (!util::ValidateUTF8(decoded)) {
return Status::Invalid("Partition segment was not valid UTF-8 after URL decoding: ",
encoded);
}
return decoded;
}
} // namespace

std::shared_ptr<Partitioning> Partitioning::Default() {
class DefaultPartitioning : public Partitioning {
public:
Expand Down Expand Up @@ -158,6 +172,21 @@ Result<Partitioning::PartitionedBatches> KeyValuePartitioning::Partition(
return out;
}

std::ostream& operator<<(std::ostream& os, SegmentEncoding segment_encoding) {
switch (segment_encoding) {
case SegmentEncoding::None:
os << "SegmentEncoding::None";
break;
case SegmentEncoding::Uri:
os << "SegmentEncoding::Uri";
break;
default:
os << "(invalid SegmentEncoding " << static_cast<int8_t>(segment_encoding) << ")";
break;
}
return os;
}

Result<compute::Expression> KeyValuePartitioning::ConvertKey(const Key& key) const {
ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_));
if (match.empty()) {
Expand Down Expand Up @@ -209,7 +238,8 @@ Result<compute::Expression> KeyValuePartitioning::ConvertKey(const Key& key) con
Result<compute::Expression> KeyValuePartitioning::Parse(const std::string& path) const {
std::vector<compute::Expression> expressions;

for (const Key& key : ParseKeys(path)) {
ARROW_ASSIGN_OR_RAISE(auto parsed, ParseKeys(path));
for (const Key& key : parsed) {
ARROW_ASSIGN_OR_RAISE(auto expr, ConvertKey(key));
if (expr == compute::literal(true)) continue;
expressions.push_back(std::move(expr));
Expand Down Expand Up @@ -259,15 +289,38 @@ Result<std::string> KeyValuePartitioning::Format(const compute::Expression& expr
return FormatValues(values);
}

std::vector<KeyValuePartitioning::Key> DirectoryPartitioning::ParseKeys(
DirectoryPartitioning::DirectoryPartitioning(std::shared_ptr<Schema> schema,
ArrayVector dictionaries,
KeyValuePartitioningOptions options)
: KeyValuePartitioning(std::move(schema), std::move(dictionaries), options) {
util::InitializeUTF8();
}

Result<std::vector<KeyValuePartitioning::Key>> DirectoryPartitioning::ParseKeys(
const std::string& path) const {
std::vector<Key> keys;

int i = 0;
for (auto&& segment : fs::internal::SplitAbstractPath(path)) {
if (i >= schema_->num_fields()) break;

keys.push_back({schema_->field(i++)->name(), std::move(segment)});
switch (options_.segment_encoding) {
case SegmentEncoding::None: {
if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(segment))) {
return Status::Invalid("Partition segment was not valid UTF-8: ", segment);
}
keys.push_back({schema_->field(i++)->name(), std::move(segment)});
break;
}
case SegmentEncoding::Uri: {
ARROW_ASSIGN_OR_RAISE(auto decoded, SafeUriUnescape(segment));
keys.push_back({schema_->field(i++)->name(), std::move(decoded)});
break;
}
default:
return Status::NotImplemented("Unknown segment encoding: ",
options_.segment_encoding);
}
}

return keys;
Expand Down Expand Up @@ -308,6 +361,20 @@ Result<std::string> DirectoryPartitioning::FormatValues(
return fs::internal::JoinAbstractPath(std::move(segments));
}

KeyValuePartitioningOptions PartitioningFactoryOptions::AsPartitioningOptions() const {
KeyValuePartitioningOptions options;
options.segment_encoding = segment_encoding;
return options;
}

HivePartitioningOptions HivePartitioningFactoryOptions::AsHivePartitioningOptions()
const {
HivePartitioningOptions options;
options.segment_encoding = segment_encoding;
options.null_fallback = null_fallback;
return options;
}

namespace {
class KeyValuePartitioningFactory : public PartitioningFactory {
protected:
Expand Down Expand Up @@ -430,6 +497,7 @@ class DirectoryPartitioningFactory : public KeyValuePartitioningFactory {
PartitioningFactoryOptions options)
: KeyValuePartitioningFactory(options), field_names_(std::move(field_names)) {
Reset();
util::InitializeUTF8();
}

std::string type_name() const override { return "schema"; }
Expand All @@ -441,7 +509,23 @@ class DirectoryPartitioningFactory : public KeyValuePartitioningFactory {
for (auto&& segment : fs::internal::SplitAbstractPath(path)) {
if (field_index == field_names_.size()) break;

RETURN_NOT_OK(InsertRepr(static_cast<int>(field_index++), segment));
switch (options_.segment_encoding) {
case SegmentEncoding::None: {
if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(segment))) {
return Status::Invalid("Partition segment was not valid UTF-8: ", segment);
}
RETURN_NOT_OK(InsertRepr(static_cast<int>(field_index++), segment));
break;
}
case SegmentEncoding::Uri: {
ARROW_ASSIGN_OR_RAISE(auto decoded, SafeUriUnescape(segment));
RETURN_NOT_OK(InsertRepr(static_cast<int>(field_index++), decoded));
break;
}
default:
return Status::NotImplemented("Unknown segment encoding: ",
options_.segment_encoding);
}
}
}

Expand All @@ -458,7 +542,8 @@ class DirectoryPartitioningFactory : public KeyValuePartitioningFactory {
// drop fields which aren't in field_names_
auto out_schema = SchemaFromColumnNames(schema, field_names_);

return std::make_shared<DirectoryPartitioning>(std::move(out_schema), dictionaries_);
return std::make_shared<DirectoryPartitioning>(std::move(out_schema), dictionaries_,
options_.AsPartitioningOptions());
}

private:
Expand All @@ -481,28 +566,50 @@ std::shared_ptr<PartitioningFactory> DirectoryPartitioning::MakeFactory(
new DirectoryPartitioningFactory(std::move(field_names), options));
}

util::optional<KeyValuePartitioning::Key> HivePartitioning::ParseKey(
const std::string& segment, const std::string& null_fallback) {
Result<util::optional<KeyValuePartitioning::Key>> HivePartitioning::ParseKey(
const std::string& segment, const HivePartitioningOptions& options) {
auto name_end = string_view(segment).find_first_of('=');
// Not round-trippable
if (name_end == string_view::npos) {
return util::nullopt;
}

// Static method, so we have no better place for it
util::InitializeUTF8();

auto name = segment.substr(0, name_end);
auto value = segment.substr(name_end + 1);
if (value == null_fallback) {
return Key{name, util::nullopt};
std::string value;
switch (options.segment_encoding) {
case SegmentEncoding::None: {
value = segment.substr(name_end + 1);
if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(value))) {
return Status::Invalid("Partition segment was not valid UTF-8: ", value);
}
break;
}
case SegmentEncoding::Uri: {
auto raw_value = util::string_view(segment).substr(name_end + 1);
ARROW_ASSIGN_OR_RAISE(value, SafeUriUnescape(raw_value));
break;
}
default:
return Status::NotImplemented("Unknown segment encoding: ",
options.segment_encoding);
}

if (value == options.null_fallback) {
return Key{std::move(name), util::nullopt};
}
return Key{name, value};
return Key{std::move(name), std::move(value)};
}

std::vector<KeyValuePartitioning::Key> HivePartitioning::ParseKeys(
Result<std::vector<KeyValuePartitioning::Key>> HivePartitioning::ParseKeys(
const std::string& path) const {
std::vector<Key> keys;

for (const auto& segment : fs::internal::SplitAbstractPath(path)) {
if (auto key = ParseKey(segment, null_fallback_)) {
ARROW_ASSIGN_OR_RAISE(auto maybe_key, ParseKey(segment, hive_options_));
if (auto key = maybe_key) {
keys.push_back(std::move(*key));
}
}
Expand All @@ -521,7 +628,7 @@ Result<std::string> HivePartitioning::FormatValues(const ScalarVector& values) c
} else if (!values[i]->is_valid) {
// If no key is available just provide a placeholder segment to maintain the
// field_index <-> path nesting relation
segments[i] = name + "=" + null_fallback_;
segments[i] = name + "=" + hive_options_.null_fallback;
} else {
segments[i] = name + "=" + values[i]->ToString();
}
Expand All @@ -533,15 +640,18 @@ Result<std::string> HivePartitioning::FormatValues(const ScalarVector& values) c
class HivePartitioningFactory : public KeyValuePartitioningFactory {
public:
explicit HivePartitioningFactory(HivePartitioningFactoryOptions options)
: KeyValuePartitioningFactory(options), null_fallback_(options.null_fallback) {}
: KeyValuePartitioningFactory(options), options_(std::move(options)) {}

std::string type_name() const override { return "hive"; }

Result<std::shared_ptr<Schema>> Inspect(
const std::vector<std::string>& paths) override {
auto options = options_.AsHivePartitioningOptions();
for (auto path : paths) {
for (auto&& segment : fs::internal::SplitAbstractPath(path)) {
if (auto key = HivePartitioning::ParseKey(segment, null_fallback_)) {
ARROW_ASSIGN_OR_RAISE(auto maybe_key,
HivePartitioning::ParseKey(segment, options));
if (auto key = maybe_key) {
RETURN_NOT_OK(InsertRepr(key->name, key->value));
}
}
Expand All @@ -565,12 +675,12 @@ class HivePartitioningFactory : public KeyValuePartitioningFactory {
auto out_schema = SchemaFromColumnNames(schema, field_names_);

return std::make_shared<HivePartitioning>(std::move(out_schema), dictionaries_,
null_fallback_);
options_.AsHivePartitioningOptions());
}
}

private:
const std::string null_fallback_;
const HivePartitioningFactoryOptions options_;
std::vector<std::string> field_names_;
};

Expand Down
Loading