Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 28 additions & 29 deletions cpp/src/arrow/dataset/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,27 +90,35 @@ ValueDescr Expression::descr() const {
return CallNotNull(*this)->descr;
}

namespace {

std::string PrintDatum(const Datum& datum) {
if (datum.is_scalar()) {
switch (datum.type()->id()) {
case Type::STRING:
case Type::LARGE_STRING:
return '"' +
Escape(util::string_view(*datum.scalar_as<BaseBinaryScalar>().value)) +
'"';

case Type::BINARY:
case Type::FIXED_SIZE_BINARY:
case Type::LARGE_BINARY:
return '"' + datum.scalar_as<BaseBinaryScalar>().value->ToHexString() + '"';

default:
break;
}
return datum.scalar()->ToString();
}
return datum.ToString();
}

} // namespace

std::string Expression::ToString() const {
if (auto lit = literal()) {
if (lit->is_scalar()) {
switch (lit->type()->id()) {
case Type::STRING:
case Type::LARGE_STRING:
return '"' +
Escape(util::string_view(*lit->scalar_as<BaseBinaryScalar>().value)) +
'"';

case Type::BINARY:
case Type::FIXED_SIZE_BINARY:
case Type::LARGE_BINARY:
return '"' + lit->scalar_as<BaseBinaryScalar>().value->ToHexString() + '"';

default:
break;
}
return lit->scalar()->ToString();
}
return lit->ToString();
return PrintDatum(*lit);
}

if (auto ref = field_ref()) {
Expand Down Expand Up @@ -763,16 +771,7 @@ Status ExtractKnownFieldValuesImpl(
auto ref = call->arguments[0].field_ref();
auto lit = call->arguments[1].literal();

auto it_success = known_values->emplace(*ref, *lit);
if (it_success.second) continue;

// A value was already known for ref; check it
auto ref_lit = it_success.first;
if (*lit != ref_lit->second) {
return Status::Invalid("Conflicting guarantees: (", ref->ToString(),
" == ", lit->ToString(), ") vs (", ref->ToString(),
" == ", ref_lit->second.ToString());
}
known_values->emplace(*ref, *lit);
}

conjunction_members->erase(unconsumed_end, conjunction_members->end());
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/arrow/dataset/file_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,12 +313,16 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio
ARROW_ASSIGN_OR_RAISE(auto groups, write_options.partitioning->Partition(batch));
batch.reset(); // drop to hopefully conserve memory

if (groups.batches.size() > static_cast<size_t>(write_options.max_partitions)) {
return Status::Invalid("Fragment would be written into ", groups.batches.size(),
" partitions. This exceeds the maximum of ",
write_options.max_partitions);
}

std::unordered_set<WriteQueue*> need_flushed;
for (size_t i = 0; i < groups.batches.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(
auto partition_expression,
and_(std::move(groups.expressions[i]), fragment->partition_expression())
.Bind(*scanner->schema()));
auto partition_expression =
and_(std::move(groups.expressions[i]), fragment->partition_expression());
auto batch = std::move(groups.batches[i]);

ARROW_ASSIGN_OR_RAISE(auto part,
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/dataset/file_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions {
/// Partitioning used to generate fragment paths.
std::shared_ptr<Partitioning> partitioning;

/// Maximum number of partitions any batch may be written into, default is 1K.
int max_partitions = 1024;

/// Template string used to generate fragment basenames.
/// {i} will be replaced by an auto incremented integer.
std::string basename_template;
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/dataset/file_ipc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ TEST_F(TestIpcFileSystemDataset, WriteWithEmptyPartitioningSchema) {
TestWriteWithEmptyPartitioningSchema();
}

TEST_F(TestIpcFileSystemDataset, WriteExceedsMaxPartitions) {
write_options_.partitioning = std::make_shared<DirectoryPartitioning>(
SchemaFromColumnNames(source_schema_, {"model"}));

// require that no batch be grouped into more than 2 written batches:
write_options_.max_partitions = 2;

auto scanner = std::make_shared<Scanner>(dataset_, scan_options_, scan_context_);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("This exceeds the maximum"),
FileSystemDataset::Write(write_options_, scanner));
}

TEST_F(TestIpcFileFormat, OpenFailureWithRelevantError) {
std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(util::string_view(""));
auto result = format_->Inspect(FileSource(buf));
Expand Down
109 changes: 75 additions & 34 deletions cpp/src/arrow/dataset/partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <vector>

#include "arrow/array/array_base.h"
#include "arrow/array/array_dict.h"
#include "arrow/array/array_nested.h"
#include "arrow/array/builder_dict.h"
#include "arrow/compute/api_scalar.h"
Expand Down Expand Up @@ -191,7 +192,7 @@ Result<Expression> KeyValuePartitioning::Parse(const std::string& path) const {
}

Result<std::string> KeyValuePartitioning::Format(const Expression& expr) const {
std::vector<Scalar*> values{static_cast<size_t>(schema_->num_fields()), nullptr};
ScalarVector values{static_cast<size_t>(schema_->num_fields()), nullptr};

ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr));
for (const auto& ref_value : known_values) {
Expand All @@ -202,15 +203,20 @@ Result<std::string> KeyValuePartitioning::Format(const Expression& expr) const {
ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_));
if (match.empty()) continue;

const auto& value = ref_value.second.scalar();
auto value = ref_value.second.scalar();

const auto& field = schema_->field(match[0]);
if (!value->type->Equals(field->type())) {
return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type,
") is invalid for ", field->ToString());
}

values[match[0]] = value.get();
if (value->type->id() == Type::DICTIONARY) {
ARROW_ASSIGN_OR_RAISE(
value, checked_cast<const DictionaryScalar&>(*value).GetEncodedValue());
}

values[match[0]] = std::move(value);
}

return FormatValues(values);
Expand All @@ -230,9 +236,9 @@ std::vector<KeyValuePartitioning::Key> DirectoryPartitioning::ParseKeys(
return keys;
}

inline util::optional<int> NextValid(const std::vector<Scalar*>& values, int first_null) {
inline util::optional<int> NextValid(const ScalarVector& values, int first_null) {
auto it = std::find_if(values.begin() + first_null + 1, values.end(),
[](Scalar* v) { return v != nullptr; });
[](const std::shared_ptr<Scalar>& v) { return v != nullptr; });

if (it == values.end()) {
return util::nullopt;
Expand All @@ -242,7 +248,7 @@ inline util::optional<int> NextValid(const std::vector<Scalar*>& values, int fir
}

Result<std::string> DirectoryPartitioning::FormatValues(
const std::vector<Scalar*>& values) const {
const ScalarVector& values) const {
std::vector<std::string> segments(static_cast<size_t>(schema_->num_fields()));

for (int i = 0; i < schema_->num_fields(); ++i) {
Expand Down Expand Up @@ -426,8 +432,7 @@ std::vector<KeyValuePartitioning::Key> HivePartitioning::ParseKeys(
return keys;
}

Result<std::string> HivePartitioning::FormatValues(
const std::vector<Scalar*>& values) const {
Result<std::string> HivePartitioning::FormatValues(const ScalarVector& values) const {
std::vector<std::string> segments(static_cast<size_t>(schema_->num_fields()));

for (int i = 0; i < schema_->num_fields(); ++i) {
Expand Down Expand Up @@ -532,19 +537,21 @@ Result<std::shared_ptr<Schema>> PartitioningOrFactory::GetOrInferSchema(

// Transform an array of counts to offsets which will divide a ListArray
// into an equal number of slices with corresponding lengths.
inline Result<std::shared_ptr<Array>> CountsToOffsets(
inline Result<std::shared_ptr<Buffer>> CountsToOffsets(
std::shared_ptr<Int64Array> counts) {
Int32Builder offset_builder;
TypedBufferBuilder<int32_t> offset_builder;
RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1));
offset_builder.UnsafeAppend(0);

int32_t current_offset = 0;
offset_builder.UnsafeAppend(current_offset);

for (int64_t i = 0; i < counts->length(); ++i) {
DCHECK_NE(counts->Value(i), 0);
auto next_offset = static_cast<int32_t>(offset_builder[i] + counts->Value(i));
offset_builder.UnsafeAppend(next_offset);
current_offset += static_cast<int32_t>(counts->Value(i));
offset_builder.UnsafeAppend(current_offset);
}

std::shared_ptr<Array> offsets;
std::shared_ptr<Buffer> offsets;
RETURN_NOT_OK(offset_builder.Finish(&offsets));
return offsets;
}
Expand Down Expand Up @@ -604,6 +611,12 @@ class StructDictionary {
RETURN_NOT_OK(builders[i].FinishInternal(&indices));

ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices));

if (fields[i]->type()->id() == Type::DICTIONARY) {
RETURN_NOT_OK(RestoreDictionaryEncoding(
checked_pointer_cast<DictionaryType>(fields[i]->type()), &column));
}

columns[i] = column.make_array();
}

Expand All @@ -612,27 +625,22 @@ class StructDictionary {

private:
Status AddOne(Datum column, std::shared_ptr<Int32Array>* fused_indices) {
ArrayData* encoded;
if (column.type()->id() != Type::DICTIONARY) {
ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column));
ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(std::move(column)));
}
encoded = column.mutable_array();

auto indices =
std::make_shared<Int32Array>(encoded->length, std::move(encoded->buffers[1]));

dictionaries_.push_back(MakeArray(std::move(encoded->dictionary)));
auto dictionary_size = static_cast<int32_t>(dictionaries_.back()->length());
auto dict_column = column.array_as<DictionaryArray>();
dictionaries_.push_back(dict_column->dictionary());
ARROW_ASSIGN_OR_RAISE(auto indices, compute::Cast(*dict_column->indices(), int32()));

if (*fused_indices == nullptr) {
*fused_indices = std::move(indices);
size_ = dictionary_size;
return Status::OK();
*fused_indices = checked_pointer_cast<Int32Array>(std::move(indices));
return IncreaseSize();
}

// It's useful to think about the case where each of dictionaries_ has size 10.
// In this case the decimal digit in the ones place is the code in dictionaries_[0],
// the tens place corresponds to dictionaries_[1], etc.
// the tens place corresponds to the code in dictionaries_[1], etc.
// The incumbent indices must be shifted to the hundreds place so as not to collide.
ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices,
compute::Multiply(indices, MakeScalar(size_)));
Expand All @@ -641,10 +649,7 @@ class StructDictionary {
compute::Add(new_fused_indices, *fused_indices));

*fused_indices = checked_pointer_cast<Int32Array>(new_fused_indices.make_array());

// XXX should probably cap this at 2**15 or so
ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_));
return Status::OK();
return IncreaseSize();
}

// expand a fused code into component dict codes, order is in order of addition
Expand All @@ -656,13 +661,48 @@ class StructDictionary {
}
}

int32_t size_;
Status RestoreDictionaryEncoding(std::shared_ptr<DictionaryType> expected_type,
Datum* column) {
DCHECK_NE(column->type()->id(), Type::DICTIONARY);
ARROW_ASSIGN_OR_RAISE(*column, compute::DictionaryEncode(std::move(*column)));

if (expected_type->index_type()->id() == Type::INT32) {
// dictionary_encode has already yielded the expected index_type
return Status::OK();
}

// cast the indices to the expected index type
auto dictionary = std::move(column->mutable_array()->dictionary);
column->mutable_array()->type = int32();

ARROW_ASSIGN_OR_RAISE(*column,
compute::Cast(std::move(*column), expected_type->index_type()));

column->mutable_array()->dictionary = std::move(dictionary);
column->mutable_array()->type = expected_type;
return Status::OK();
}

Status IncreaseSize() {
auto factor = static_cast<int32_t>(dictionaries_.back()->length());

if (internal::MultiplyWithOverflow(size_, factor, &size_)) {
return Status::CapacityError("Max groups exceeded");
}
return Status::OK();
}

int32_t size_ = 1;
ArrayVector dictionaries_;
};

Result<std::shared_ptr<StructArray>> MakeGroupings(const StructArray& by) {
if (by.num_fields() == 0) {
return Status::NotImplemented("Grouping with no criteria");
return Status::Invalid("Grouping with no criteria");
}

if (by.null_count() != 0) {
return Status::Invalid("Grouping with null criteria");
}

ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields()));
Expand All @@ -685,8 +725,9 @@ Result<std::shared_ptr<StructArray>> MakeGroupings(const StructArray& by) {
checked_pointer_cast<Int64Array>(fused_counts_and_values->GetFieldByName("counts"));
ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(counts)));

ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices,
ListArray::FromArrays(*offsets, *sort_indices));
auto grouped_sort_indices =
std::make_shared<ListArray>(list(sort_indices->type()), unique_rows->length(),
std::move(offsets), std::move(sort_indices));

return StructArray::Make(
ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)},
Expand Down
Loading